kite-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. kite/__init__.py +46 -0
  2. kite/ab_testing.py +384 -0
  3. kite/agent.py +556 -0
  4. kite/agents/__init__.py +3 -0
  5. kite/agents/plan_execute.py +191 -0
  6. kite/agents/react_agent.py +509 -0
  7. kite/agents/reflective_agent.py +90 -0
  8. kite/agents/rewoo.py +119 -0
  9. kite/agents/tot.py +151 -0
  10. kite/conversation.py +125 -0
  11. kite/core.py +974 -0
  12. kite/data_loaders.py +111 -0
  13. kite/embedding_providers.py +372 -0
  14. kite/llm_providers.py +1278 -0
  15. kite/memory/__init__.py +6 -0
  16. kite/memory/advanced_rag.py +333 -0
  17. kite/memory/graph_rag.py +719 -0
  18. kite/memory/session_memory.py +423 -0
  19. kite/memory/vector_memory.py +579 -0
  20. kite/monitoring.py +611 -0
  21. kite/observers.py +107 -0
  22. kite/optimization/__init__.py +9 -0
  23. kite/optimization/resource_router.py +80 -0
  24. kite/persistence.py +42 -0
  25. kite/pipeline/__init__.py +5 -0
  26. kite/pipeline/deterministic_pipeline.py +323 -0
  27. kite/pipeline/reactive_pipeline.py +171 -0
  28. kite/pipeline_manager.py +15 -0
  29. kite/routing/__init__.py +6 -0
  30. kite/routing/aggregator_router.py +325 -0
  31. kite/routing/llm_router.py +149 -0
  32. kite/routing/semantic_router.py +228 -0
  33. kite/safety/__init__.py +6 -0
  34. kite/safety/circuit_breaker.py +360 -0
  35. kite/safety/guardrails.py +82 -0
  36. kite/safety/idempotency_manager.py +304 -0
  37. kite/safety/kill_switch.py +75 -0
  38. kite/tool.py +183 -0
  39. kite/tool_registry.py +87 -0
  40. kite/tools/__init__.py +21 -0
  41. kite/tools/code_execution.py +53 -0
  42. kite/tools/contrib/__init__.py +19 -0
  43. kite/tools/contrib/calculator.py +26 -0
  44. kite/tools/contrib/datetime_utils.py +20 -0
  45. kite/tools/contrib/linkedin.py +428 -0
  46. kite/tools/contrib/web_search.py +30 -0
  47. kite/tools/mcp/__init__.py +31 -0
  48. kite/tools/mcp/database_mcp.py +267 -0
  49. kite/tools/mcp/gdrive_mcp_server.py +503 -0
  50. kite/tools/mcp/gmail_mcp_server.py +601 -0
  51. kite/tools/mcp/postgres_mcp_server.py +490 -0
  52. kite/tools/mcp/slack_mcp_server.py +538 -0
  53. kite/tools/mcp/stripe_mcp_server.py +219 -0
  54. kite/tools/search.py +90 -0
  55. kite/tools/system_tools.py +54 -0
  56. kite/tools_manager.py +27 -0
  57. kite_agent-0.1.0.dist-info/METADATA +621 -0
  58. kite_agent-0.1.0.dist-info/RECORD +61 -0
  59. kite_agent-0.1.0.dist-info/WHEEL +5 -0
  60. kite_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
  61. kite_agent-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,228 @@
1
+ """
2
+ Semantic Router Implementation
3
+ Based on Chapter 6: The Semantic Router
4
+
5
+ A router that classifies user intent and routes to specialists.
6
+ Uses embeddings for fast, cached classification.
7
+ """
8
+
9
+ import os
10
+ import hashlib
11
+ from typing import Dict, List, Optional, Callable
12
+ from dataclasses import dataclass
13
+ from functools import lru_cache
14
+ import numpy as np
15
+ from dotenv import load_dotenv
16
+
17
+ load_dotenv()
18
+
19
+
20
+ # ============================================================================
21
+ # ROUTING CONFIGURATION
22
+ # ============================================================================
23
+
24
+ @dataclass
25
+ class Route:
26
+ """Define a route with examples and handler."""
27
+ name: str
28
+ description: str
29
+ examples: List[str]
30
+ handler: Callable
31
+ embedding: Optional[np.ndarray] = None
32
+
33
+
34
+ # ============================================================================
35
+ # SEMANTIC ROUTER
36
+ # ============================================================================
37
+
38
+ class SemanticRouter:
39
+ """
40
+ Routes user queries to specialist agents based on semantic similarity.
41
+
42
+ Features:
43
+ - Embedding-based classification (fast, cheap)
44
+ - LRU caching for repeated queries
45
+ - Confidence thresholds
46
+ - Ambiguity handling
47
+
48
+ Example:
49
+ router = SemanticRouter(embedding_provider=my_embedder)
50
+ router.add_route("technical", tech_agent.handle, [
51
+ "my internet is down",
52
+ "can't connect to wifi",
53
+ "server not responding"
54
+ ])
55
+
56
+ response = router.route("my connection keeps dropping")
57
+ """
58
+
59
+ def __init__(self,
60
+ confidence_threshold: float = 0.75,
61
+ embedding_provider = None):
62
+ self.routes: List[Route] = []
63
+ self.confidence_threshold = confidence_threshold
64
+ self.embedding_provider = embedding_provider
65
+
66
+ # Precompute route embeddings if routes exist (none initially)
67
+ self._compute_route_embeddings()
68
+
69
+ @lru_cache(maxsize=10000)
70
+ def _get_embedding(self, text: str) -> np.ndarray:
71
+ """Get embedding for text with LRU caching."""
72
+ if self.embedding_provider:
73
+ return np.array(self.embedding_provider.embed(text))
74
+
75
+ raise RuntimeError("No embedding provider configured for SemanticRouter")
76
+
77
+ def _compute_route_embeddings(self):
78
+ """Precompute embeddings for all route examples."""
79
+ if not self.routes:
80
+ return
81
+
82
+ print("[CHART] Computing route embeddings...")
83
+ for route in self.routes:
84
+ if route.embedding is None:
85
+ # Combine all examples into one text for the route
86
+ combined_text = " | ".join(route.examples)
87
+ route.embedding = self._get_embedding(combined_text)
88
+
89
+ def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
90
+ """Calculate cosine similarity between two vectors."""
91
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
92
+
93
+ def classify_intent(self, query: str) -> tuple[Route, float]:
94
+ """
95
+ Classify query intent using semantic similarity.
96
+
97
+ Args:
98
+ query: User query
99
+
100
+ Returns:
101
+ (best_route, confidence_score)
102
+ """
103
+ if not self.routes:
104
+ raise RuntimeError("No routes configured in SemanticRouter")
105
+
106
+ # Get query embedding (cached if seen before)
107
+ query_embedding = self._get_embedding(query)
108
+
109
+ # Calculate similarity to each route
110
+ scores = []
111
+ for route in self.routes:
112
+ similarity = self._cosine_similarity(query_embedding, route.embedding)
113
+ scores.append((route, similarity))
114
+
115
+ # Sort by similarity
116
+ scores.sort(key=lambda x: x[1], reverse=True)
117
+
118
+ best_route, best_score = scores[0]
119
+
120
+ print(f"\n Intent Classification:")
121
+ print(f" Query: {query}")
122
+ print(f" Best Match: {best_route.name} (confidence: {best_score:.2%})")
123
+
124
+ return best_route, best_score
125
+
126
+ async def route(self, query: str, context: Optional[Dict] = None) -> Dict:
127
+ """
128
+ Route query to appropriate specialist agent.
129
+
130
+ Args:
131
+ query: User query
132
+ context: Optional context to pass to handler
133
+
134
+ Returns:
135
+ Response dictionary with routing info
136
+ """
137
+ # Classify intent
138
+ route, confidence = self.classify_intent(query)
139
+
140
+ # Handle low confidence (ambiguity)
141
+ if confidence < self.confidence_threshold:
142
+ print(f"[WARN] Low confidence ({confidence:.2%} < {self.confidence_threshold:.2%})")
143
+
144
+ return {
145
+ "route": "none",
146
+ "confidence": confidence,
147
+ "response": "I'm not sure how to help with that. Could you be more specific?",
148
+ "needs_clarification": True,
149
+ "suggested_routes": [r.name for r, _ in self._get_top_routes(query, n=3)]
150
+ }
151
+
152
+ # Route to specialist
153
+ print(f"[OK] Routing to {route.name}")
154
+ if context:
155
+ response = route.handler(query, context)
156
+ else:
157
+ try:
158
+ response = route.handler(query)
159
+ except TypeError:
160
+ # Fallback if handler expects context but none provided
161
+ response = route.handler(query, {})
162
+
163
+ # Support both sync and async handlers
164
+ import asyncio
165
+ if asyncio.iscoroutine(response):
166
+ response = await response
167
+
168
+ # If response is a dict (from Agent.run), extract the 'response' text if available
169
+ if isinstance(response, dict) and 'response' in response:
170
+ response_text = response['response']
171
+ else:
172
+ response_text = str(response)
173
+
174
+ return {
175
+ "route": route.name,
176
+ "confidence": confidence,
177
+ "response": response_text,
178
+ "needs_clarification": False
179
+ }
180
+
181
+ def _get_top_routes(self, query: str, n: int = 3) -> List[tuple[Route, float]]:
182
+ """Get top N routes by similarity."""
183
+ query_embedding = self._get_embedding(query)
184
+
185
+ scores = [
186
+ (route, self._cosine_similarity(query_embedding, route.embedding))
187
+ for route in self.routes
188
+ ]
189
+
190
+ scores.sort(key=lambda x: x[1], reverse=True)
191
+ return scores[:n]
192
+
193
+ def add_route(self, name: str, examples: List[str] | str, description: str = "", handler: Callable = None):
194
+ """Add a new route dynamically."""
195
+ if isinstance(examples, str):
196
+ examples = [examples]
197
+
198
+ route = Route(
199
+ name=name,
200
+ description=description,
201
+ examples=examples,
202
+ handler=handler or (lambda x: f"Default response for {name}")
203
+ )
204
+
205
+ # Compute embedding if provider exists
206
+ if self.embedding_provider:
207
+ combined_text = " | ".join(examples)
208
+ route.embedding = self._get_embedding(combined_text)
209
+
210
+ self.routes.append(route)
211
+ print(f"[OK] Added route: {name}")
212
+
213
+ def get_stats(self) -> Dict:
214
+ """Get router statistics."""
215
+ cache_info = self._get_embedding.cache_info()
216
+
217
+ return {
218
+ "total_routes": len(self.routes),
219
+ "confidence_threshold": self.confidence_threshold,
220
+ "cache_size": cache_info.currsize,
221
+ "cache_hits": cache_info.hits,
222
+ "cache_misses": cache_info.misses,
223
+ "cache_hit_rate": (
224
+ cache_info.hits / (cache_info.hits + cache_info.misses)
225
+ if (cache_info.hits + cache_info.misses) > 0
226
+ else 0
227
+ )
228
+ }
@@ -0,0 +1,6 @@
1
+ """Safety patterns module."""
2
from .circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
    CircuitBreakerError,
    CircuitState,
)
from .idempotency_manager import IdempotencyManager, IdempotencyConfig
from .kill_switch import KillSwitch

# CircuitBreakerError is exported so callers can catch the exception raised by
# CircuitBreaker.call() without reaching into the submodule.
__all__ = [
    'CircuitBreaker',
    'CircuitBreakerConfig',
    'CircuitBreakerError',
    'CircuitState',
    'IdempotencyManager',
    'IdempotencyConfig',
    'KillSwitch',
]
@@ -0,0 +1,360 @@
1
+ """
2
+ Circuit Breaker Pattern for AI Agent Operations
3
+
4
+ Prevents cascading failures when an agent repeatedly attempts failed operations.
5
+ This is critical for write-access AI systems where retries can cause real damage.
6
+
7
+ Author: [Your Name]
8
+ License: MIT
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime, timedelta
13
+ from enum import Enum
14
+ from typing import Dict, Optional, Callable, Any
15
+ import logging
16
+ from functools import wraps
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class CircuitState(Enum):
    """Lifecycle states a circuit breaker moves between."""

    # Normal operation: calls pass through to the protected operation.
    CLOSED = "closed"
    # Tripped: calls are rejected without running the operation.
    OPEN = "open"
    # Recovery probe: trial calls check whether the service has recovered.
    HALF_OPEN = "half_open"
26
+
27
+
28
@dataclass
class CircuitBreakerConfig:
    """Thresholds and timing that tune a circuit breaker."""

    # Consecutive failures that move the circuit from CLOSED to OPEN.
    failure_threshold: int = 3
    # Consecutive successes in HALF_OPEN needed to close the circuit again.
    success_threshold: int = 2
    # Seconds an OPEN circuit waits before a HALF_OPEN trial is allowed.
    timeout_seconds: int = 60
    # Maximum number of trial calls in flight while HALF_OPEN.
    half_open_max_calls: int = 1
35
+
36
+
37
@dataclass
class CircuitStats:
    """Running counters used to monitor a circuit breaker's health."""

    # Every call attempt, including ones the breaker rejected.
    total_calls: int = 0
    # Attempts whose wrapped function returned normally.
    successful_calls: int = 0
    # Attempts whose wrapped function raised.
    failed_calls: int = 0
    # Attempts refused while OPEN or while HALF_OPEN was at capacity.
    rejected_calls: int = 0
    # Timestamp of the most recent failure, if any.
    last_failure_time: Optional[datetime] = None
    # Timestamp of the most recent success, if any.
    last_success_time: Optional[datetime] = None
    # Number of state transitions so far.
    state_changes: int = 0
47
+
48
+
49
class CircuitBreaker:
    """
    Circuit breaker for AI agent operations.

    Use this to wrap any operation that:
    - Makes external API calls
    - Modifies state
    - Costs money
    - Could fail repeatedly

    State machine:
    - CLOSED: calls run; ``failure_threshold`` consecutive failures open it.
    - OPEN: calls are rejected with CircuitBreakerError until
      ``timeout_seconds`` have elapsed; the next call then moves it to
      HALF_OPEN.
    - HALF_OPEN: at most ``half_open_max_calls`` trial calls may be in flight.
      ``success_threshold`` consecutive successes close the circuit, and any
      failure re-opens it.

    Example:
        circuit_breaker = CircuitBreaker(
            name="stripe_refunds",
            config=CircuitBreakerConfig(failure_threshold=3)
        )

        @circuit_breaker.protected
        def process_refund(order_id: str, amount: float):
            return stripe.Refund.create(...)
    """

    def __init__(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None,
        on_open: Optional[Callable] = None,
        on_close: Optional[Callable] = None
    ):
        """
        Args:
            name: Identifier used in log messages and errors.
            config: Thresholds and timeouts; defaults to CircuitBreakerConfig().
            on_open: Called as ``on_open(name, stats)`` when the circuit opens.
            on_close: Called as ``on_close(name, stats)`` when the circuit closes.
        """
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.stats = CircuitStats()

        # Callbacks for state changes
        self.on_open = on_open
        self.on_close = on_close

        # Failure tracking
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.open_until: Optional[datetime] = None
        self.half_open_calls = 0

        # Log the effective config; the ``config`` argument may be None.
        logger.info(f"Circuit breaker '{name}' initialized: {self.config}")

    def _change_state(self, new_state: CircuitState, reason: str):
        """Change circuit state and trigger callbacks (no-op if unchanged)."""
        if new_state == self.state:
            return

        old_state = self.state
        self.state = new_state
        self.stats.state_changes += 1

        logger.warning(
            f"Circuit '{self.name}': {old_state.value} -> {new_state.value} "
            f"(Reason: {reason})"
        )

        if new_state == CircuitState.OPEN and self.on_open:
            self.on_open(self.name, self.stats)
        elif new_state == CircuitState.CLOSED and self.on_close:
            self.on_close(self.name, self.stats)

    def _should_attempt_reset(self) -> bool:
        """Return True when the circuit is OPEN and its timeout has expired."""
        if self.state != CircuitState.OPEN:
            return False

        return bool(self.open_until and datetime.now() >= self.open_until)

    def _trip(self, reason: str):
        """Open the circuit for ``config.timeout_seconds`` from now."""
        self.open_until = (
            datetime.now() +
            timedelta(seconds=self.config.timeout_seconds)
        )
        self._change_state(CircuitState.OPEN, reason)

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Function to call
            *args, **kwargs: Arguments to pass to function

        Returns:
            Result from function

        Raises:
            CircuitBreakerError: If the circuit is OPEN, or HALF_OPEN with no
                free trial slot. Rejected calls still count in ``total_calls``.
            Original exception: If function fails
        """
        self.stats.total_calls += 1

        # Check if circuit should transition from OPEN to HALF_OPEN
        if self._should_attempt_reset():
            self._change_state(CircuitState.HALF_OPEN, "Timeout expired")
            self.half_open_calls = 0

        # Block calls if circuit is OPEN
        if self.state == CircuitState.OPEN:
            self.stats.rejected_calls += 1
            raise CircuitBreakerError(
                f"Circuit '{self.name}' is OPEN. "
                f"Reset at {self.open_until}"
            )

        # Limit concurrent calls in HALF_OPEN state. Remember whether this call
        # took a slot so it is released even if the state changes meanwhile.
        holds_half_open_slot = False
        if self.state == CircuitState.HALF_OPEN:
            if self.half_open_calls >= self.config.half_open_max_calls:
                self.stats.rejected_calls += 1
                raise CircuitBreakerError(
                    f"Circuit '{self.name}' is HALF_OPEN and at capacity"
                )
            self.half_open_calls += 1
            holds_half_open_slot = True

        try:
            result = func(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        else:
            # Outside the ``try`` so an exception from an on_close callback is
            # not misrecorded as a failure of ``func``.
            self._on_success()
            return result
        finally:
            if holds_half_open_slot:
                self.half_open_calls = max(0, self.half_open_calls - 1)

    def _on_success(self):
        """Record a success; enough consecutive ones close a HALF_OPEN circuit."""
        self.stats.successful_calls += 1
        self.stats.last_success_time = datetime.now()
        self.consecutive_failures = 0

        if self.state == CircuitState.HALF_OPEN:
            self.consecutive_successes += 1

            if self.consecutive_successes >= self.config.success_threshold:
                self._change_state(
                    CircuitState.CLOSED,
                    f"{self.consecutive_successes} consecutive successes"
                )
                self.consecutive_successes = 0

        logger.debug(f"Circuit '{self.name}': Call succeeded")

    def _on_failure(self):
        """Record a failure and open the circuit when warranted."""
        self.stats.failed_calls += 1
        self.stats.last_failure_time = datetime.now()
        self.consecutive_successes = 0
        self.consecutive_failures += 1

        logger.warning(
            f"Circuit '{self.name}': Call failed "
            f"({self.consecutive_failures}/{self.config.failure_threshold})"
        )

        if self.state == CircuitState.HALF_OPEN:
            # A failed trial means the service has not recovered. Re-open
            # immediately instead of waiting for the CLOSED failure threshold
            # (which a trial success would otherwise have reset).
            self._trip("Failure during half-open trial")
        elif self.consecutive_failures >= self.config.failure_threshold:
            self._trip(f"{self.consecutive_failures} consecutive failures")

    def protected(self, func: Callable) -> Callable:
        """
        Decorator to protect a function with circuit breaker.

        Example:
            @circuit_breaker.protected
            def risky_operation():
                return external_api.call()
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            return self.call(func, *args, **kwargs)
        return wrapper

    def reset(self):
        """Manually reset circuit to closed state and clear all counters."""
        self._change_state(CircuitState.CLOSED, "Manual reset")
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.open_until = None
        self.half_open_calls = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get current statistics (``success_rate`` is over all attempts, rejected included)."""
        return {
            "name": self.name,
            "state": self.state.value,
            "total_calls": self.stats.total_calls,
            "successful_calls": self.stats.successful_calls,
            "failed_calls": self.stats.failed_calls,
            "rejected_calls": self.stats.rejected_calls,
            "success_rate": (
                self.stats.successful_calls / self.stats.total_calls
                if self.stats.total_calls > 0 else 0
            ),
            "consecutive_failures": self.consecutive_failures,
            "open_until": self.open_until.isoformat() if self.open_until else None,
            "last_failure": (
                self.stats.last_failure_time.isoformat()
                if self.stats.last_failure_time else None
            )
        }
258
+
259
+
260
class CircuitBreakerError(Exception):
    """Signals that a circuit breaker refused to run the protected operation.

    Raised while the circuit is OPEN, or while it is HALF_OPEN with every
    trial slot already taken.
    """
263
+
264
+
265
class CircuitBreakerRegistry:
    """
    Keeps one named circuit breaker per operation.

    Use this when several operations each need an independent circuit.
    """

    def __init__(self):
        # Breakers keyed by the name they were created with.
        self._breakers: Dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None
    ) -> CircuitBreaker:
        """Return the breaker registered as *name*, creating it on first use.

        *config* only applies when the breaker is created; later calls return
        the existing breaker unchanged.
        """
        breaker = self._breakers.get(name)
        if breaker is None:
            breaker = CircuitBreaker(name, config)
            self._breakers[name] = breaker
        return breaker

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Return ``get_stats()`` of every registered breaker, keyed by name."""
        all_stats: Dict[str, Dict[str, Any]] = {}
        for breaker_name, breaker in self._breakers.items():
            all_stats[breaker_name] = breaker.get_stats()
        return all_stats

    def reset_all(self):
        """Manually reset every registered breaker."""
        for registered in self._breakers.values():
            registered.reset()
296
+
297
+
298
+ # Global registry instance
299
registry = CircuitBreakerRegistry()


# Shorthand for registry.get_or_create() on the module-wide registry.
def circuit_breaker(
    name: str,
    config: Optional[CircuitBreakerConfig] = None
) -> CircuitBreaker:
    """Return the breaker called *name* from the global ``registry``.

    It is created with *config* on first use; later calls return the
    existing instance.
    """
    return registry.get_or_create(name=name, config=config)
309
+
310
+
311
if __name__ == "__main__":
    # Demo: a flaky dependency trips the breaker, which later probes recovery.
    import time

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    demo_config = CircuitBreakerConfig(failure_threshold=3, timeout_seconds=5)
    breaker = CircuitBreaker(name="example_api", config=demo_config)

    call_count = 0

    def flaky_api_call():
        """Fail on the first five invocations, succeed on every later one."""
        global call_count
        call_count += 1
        if call_count > 5:
            return {"status": "success", "data": "..."}
        raise Exception("API Error")

    print("\n=== Testing Circuit Breaker ===\n")

    for attempt in range(1, 11):
        print(f"\nAttempt {attempt}:")
        try:
            outcome = breaker.call(flaky_api_call)
        except CircuitBreakerError as err:
            print(f" Circuit Open: {err}")
        except Exception as err:
            print(f" API Error: {err}")
        else:
            print(f"[OK] Success: {outcome}")

        time.sleep(1)

    print("\n=== Final Statistics ===")
    print(breaker.get_stats())
@@ -0,0 +1,82 @@
1
+ """
2
+ Guardrails / Safety Patterns (Chapter 18)
3
+ Provides mechanisms to ensure agent inputs and outputs adhere to safety and structure guidelines.
4
+ """
5
+
6
+ from typing import Type, Optional, Any, Dict
7
+ from pydantic import BaseModel, ValidationError, Field
8
+ import re
9
+ import logging
10
+
11
+ logger = logging.getLogger("Guardrails")
12
+
13
class OutputGuardrail:
    """
    Enforces structured output from an LLM using Pydantic models.
    Chapter 18: Output Filtering / Post-processing.
    """
    def __init__(self, model: Type[BaseModel], fix_on_failure: bool = True):
        """
        Args:
            model: Pydantic model class the LLM output must validate against.
            fix_on_failure: If True, a validation failure returns None so the
                caller can retry or repair; if False the ValidationError is
                re-raised.
        """
        self.model = model
        self.fix_on_failure = fix_on_failure

    @staticmethod
    def _extract_json_candidate(output: str) -> str:
        """Return the substring of *output* most likely to hold the JSON payload.

        Preference order:
        1. The first ``{`` that starts a complete, decodable JSON object.
           Surrounding prose, even prose containing ``}``, is ignored.
        2. The span from the first ``{`` to the last ``}`` (the previous
           behaviour), so malformed payloads still reach the model validator
           and produce a meaningful error.
        3. The whole output with Markdown code fences stripped.
        """
        import json

        decoder = json.JSONDecoder()
        start = output.find("{")
        while start != -1:
            try:
                _, length = decoder.raw_decode(output[start:])
            except ValueError:
                start = output.find("{", start + 1)
                continue
            return output[start:start + length]

        json_match = re.search(r"\{.*\}", output, re.DOTALL)
        if json_match:
            return json_match.group(0).strip()

        # Fallback to fence cleanup when there are no braces at all
        clean_output = output.strip()
        if clean_output.startswith("```json"):
            clean_output = clean_output[7:]
        if clean_output.startswith("```"):
            clean_output = clean_output[3:]
        if clean_output.endswith("```"):
            clean_output = clean_output[:-3]
        return clean_output.strip()

    def validate(self, output: str) -> Optional[BaseModel]:
        """
        Validate and parse the LLM output against ``self.model``.

        Returns:
            The parsed model instance, or None when parsing fails (and, for
            validation errors, when ``fix_on_failure`` is True).

        Raises:
            ValidationError: If validation fails and ``fix_on_failure`` is False.
        """
        try:
            clean_output = self._extract_json_candidate(output)
            parsed = self.model.model_validate_json(clean_output)
        except ValidationError as e:
            logger.warning(f"Guardrail validation failed: {e}")
            if self.fix_on_failure:
                return None  # Signal need for retry/fix
            raise
        except Exception as e:
            logger.error(f"Guardrail parsing error: {e}")
            return None

        logger.info(f"Guardrail passed: {type(parsed).__name__}")
        return parsed
58
+
59
class StandardEvaluation(BaseModel):
    """Structured critique an agent produces when reviewing an output."""

    # Numeric grade. The description asks for 1-10, but the range is not enforced.
    score: int = Field(
        description="Score from 1-10",
    )
    # Concrete, actionable suggestions.
    feedback: str = Field(
        description="Specific feedback for improvement",
    )
    # Final verdict on the reviewed output.
    approved: bool = Field(
        description="Whether the output is acceptable",
    )
64
+
65
class InputGuardrail:
    """
    Filters unsafe or irrelevant user inputs before they reach the agent.
    Chapter 18: Input Validation / Sanitization.

    Matching is a case-insensitive substring test after collapsing every run
    of whitespace to a single space. For example, "Ignore   ALL instructions"
    still matches the term "ignore all instructions".
    """
    def __init__(self, forbidden_terms: Optional[list] = None):
        """
        Args:
            forbidden_terms: Phrases that mark an input as unsafe. None selects
                the built-in defaults; an explicit list (even an empty one) is
                used as given.
        """
        if forbidden_terms is None:
            forbidden_terms = ["ignore all instructions", "system prompt"]
        self.forbidden_terms = forbidden_terms

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case *text* and collapse each whitespace run to one space."""
        return " ".join(text.lower().split())

    def check(self, user_input: str) -> bool:
        """
        Returns True if safe, False if unsafe.
        """
        content = self._normalize(user_input)
        for term in self.forbidden_terms:
            # Normalise the term too, so mixed-case custom terms still match.
            if self._normalize(term) in content:
                logger.warning(f"Guardrail blocked input containing: '{term}'")
                return False
        return True