cite-agent 1.3.9__py3-none-any.whl → 1.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cite_agent/__init__.py +13 -13
  2. cite_agent/__version__.py +1 -1
  3. cite_agent/action_first_mode.py +150 -0
  4. cite_agent/adaptive_providers.py +413 -0
  5. cite_agent/archive_api_client.py +186 -0
  6. cite_agent/auth.py +0 -1
  7. cite_agent/auto_expander.py +70 -0
  8. cite_agent/cache.py +379 -0
  9. cite_agent/circuit_breaker.py +370 -0
  10. cite_agent/citation_network.py +377 -0
  11. cite_agent/cli.py +8 -16
  12. cite_agent/cli_conversational.py +113 -3
  13. cite_agent/confidence_calibration.py +381 -0
  14. cite_agent/deduplication.py +325 -0
  15. cite_agent/enhanced_ai_agent.py +689 -371
  16. cite_agent/error_handler.py +228 -0
  17. cite_agent/execution_safety.py +329 -0
  18. cite_agent/full_paper_reader.py +239 -0
  19. cite_agent/observability.py +398 -0
  20. cite_agent/offline_mode.py +348 -0
  21. cite_agent/paper_comparator.py +368 -0
  22. cite_agent/paper_summarizer.py +420 -0
  23. cite_agent/pdf_extractor.py +350 -0
  24. cite_agent/proactive_boundaries.py +266 -0
  25. cite_agent/quality_gate.py +442 -0
  26. cite_agent/request_queue.py +390 -0
  27. cite_agent/response_enhancer.py +257 -0
  28. cite_agent/response_formatter.py +458 -0
  29. cite_agent/response_pipeline.py +295 -0
  30. cite_agent/response_style_enhancer.py +259 -0
  31. cite_agent/self_healing.py +418 -0
  32. cite_agent/similarity_finder.py +524 -0
  33. cite_agent/streaming_ui.py +13 -9
  34. cite_agent/thinking_blocks.py +308 -0
  35. cite_agent/tool_orchestrator.py +416 -0
  36. cite_agent/trend_analyzer.py +540 -0
  37. cite_agent/unpaywall_client.py +226 -0
  38. {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/METADATA +15 -1
  39. cite_agent-1.4.3.dist-info/RECORD +62 -0
  40. cite_agent-1.3.9.dist-info/RECORD +0 -32
  41. {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/WHEEL +0 -0
  42. {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/entry_points.txt +0 -0
  43. {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/licenses/LICENSE +0 -0
  44. {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,418 @@
1
+ """
2
+ Self-Healing Agent Mechanisms
3
+ Auto-recovery from common failures, learns what works
4
+ """
5
+
6
+ import asyncio
7
+ import time
8
+ from dataclasses import dataclass, field
9
+ from enum import Enum
10
+ from typing import Dict, List, Optional, Callable, Any
11
+ from datetime import datetime, timedelta
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class FailureType(Enum):
    """Categories of failure the self-healing layer knows how to label."""

    # Provider answers, but slowly
    PROVIDER_SLOW = "provider_slow"
    # Provider is completely down
    PROVIDER_DOWN = "provider_down"
    # A rate limit was hit
    RATE_LIMIT = "rate_limit"
    # A request timed out
    TIMEOUT = "timeout"
    # Responses are getting worse
    DEGRADED_QUALITY = "degraded_quality"
    # Memory usage keeps climbing
    MEMORY_LEAK = "memory_leak"
    # A circuit breaker is open
    CIRCUIT_OPEN = "circuit_open"
26
+
27
+
28
class RecoveryAction(Enum):
    """Recovery strategies the agent can apply in response to a failure."""

    # Try a different provider
    SWITCH_PROVIDER = "switch_provider"
    # Reduce features
    DEGRADE_MODE = "degrade_mode"
    # Retry with backoff
    RETRY_EXPONENTIAL = "retry_exponential"
    # Clear caches
    CLEAR_CACHE = "clear_cache"
    # Restart the connection
    RESTART_SESSION = "restart_session"
    # Use local data only
    FALLBACK_LOCAL = "fallback_local"
    # Tell the user about the issue
    ALERT_USER = "alert_user"
37
+
38
+
39
@dataclass
class FailureEvent:
    """A single detected failure.

    Instances order by ``severity`` (via ``__lt__``), so a list of events
    can be passed to ``sorted()``/``min()``/``max()`` directly.
    """
    failure_type: FailureType
    severity: float  # 0.0 to 1.0
    # Free-form details supplied by the detector (provider name, metric, ...)
    context: Dict[str, Any] = field(default_factory=dict)
    # Creation time (naive local time from datetime.now)
    timestamp: datetime = field(default_factory=datetime.now)

    def __lt__(self, other):
        """Order events by severity.

        Returns ``NotImplemented`` for non-``FailureEvent`` operands so Python
        can try the reflected comparison and otherwise raise a regular
        ``TypeError`` instead of an ``AttributeError`` on ``other.severity``.
        """
        if not isinstance(other, FailureEvent):
            return NotImplemented
        return self.severity < other.severity
50
+
51
+
52
@dataclass
class RecoveryHistory:
    """One recorded recovery attempt and whether it worked."""
    # Which kind of failure triggered the attempt
    failure_type: FailureType
    # Which recovery action was applied
    action_taken: RecoveryAction
    # Outcome reported for the action
    success: bool
    # When the record was created (naive local time from datetime.now)
    timestamp: datetime = field(default_factory=datetime.now)
    # Optional free-form notes
    details: str = ""
60
+
61
+
62
class SelfHealingAgent:
    """
    Automatically detects failures and recovers gracefully

    Features:
    - Detects: slow providers, rate limits, metric degradation, high memory
    - Responds: switches providers, backs off, clears caches, degrades mode
    - Learns: records every recovery attempt and whether it succeeded
    - Remembers: recent failures per type (pruned to the last hour whenever
      a failure summary is requested)

    Custom behaviour is plugged in through ``recovery_callbacks``, a mapping
    from RecoveryAction to a callable that is invoked with the action's
    parameters as keyword arguments. A callback may be a plain function or a
    coroutine function; its return value is interpreted as success/failure.
    """

    def __init__(self):
        # Failure detection thresholds
        self.slow_threshold_ms = 5000      # >5s = slow
        self.degradation_threshold = 0.2   # >20% relative change = degraded
        self.memory_threshold_mb = 500     # >500MB = leak

        # Seconds to wait before reporting a rate-limit retry
        self.rate_limit_backoff_s = 5

        # Failure tracking: one list of events per failure type
        self.recent_failures: Dict[FailureType, List[FailureEvent]] = {
            ft: [] for ft in FailureType
        }

        # Recovery history (learn what works)
        self.recovery_history: List[RecoveryHistory] = []

        # Degradation state
        self.is_degraded = False
        self.degradation_reason: Optional[str] = None
        self.degradation_started: Optional[datetime] = None

        # Recovery callbacks (sync or async), keyed by action
        self.recovery_callbacks: Dict[RecoveryAction, Callable] = {}

    def _record_failure(
        self,
        failure_type: FailureType,
        severity: float,
        context: Dict[str, Any]
    ) -> FailureEvent:
        """Create a FailureEvent and append it to the per-type failure list."""
        failure = FailureEvent(
            failure_type=failure_type,
            severity=severity,
            context=context
        )
        self.recent_failures[failure_type].append(failure)
        return failure

    def detect_slow_provider(
        self,
        provider: str,
        latency_ms: float,
        recent_latencies: List[float]
    ) -> bool:
        """Record a PROVIDER_SLOW failure when latency exceeds the threshold.

        Args:
            provider: Provider name, used for logging and event context.
            latency_ms: Latency of the latest request in milliseconds.
            recent_latencies: Accepted for interface compatibility; not
                consulted by the current implementation.

        Returns:
            True if a failure was recorded, False otherwise.
        """
        if latency_ms > self.slow_threshold_ms:
            logger.warning(f"🐢 Provider '{provider}' is slow: {latency_ms}ms")
            self._record_failure(
                FailureType.PROVIDER_SLOW,
                min(1.0, latency_ms / 10000),  # Normalize: 10s or more -> 1.0
                {"provider": provider, "latency_ms": latency_ms}
            )
            return True

        return False

    def detect_rate_limiting(
        self,
        provider: str,
        error_message: str
    ) -> bool:
        """Record a RATE_LIMIT failure when the error text looks like one.

        Matching is a case-insensitive substring search against a fixed list
        of indicators. An empty or ``None`` message never matches.

        Returns:
            True if a failure was recorded, False otherwise.
        """
        rate_limit_indicators = (
            "rate limit",
            "too many requests",
            "429",
            "quota exceeded",
            "requests per minute",
        )

        # Lower-case once instead of once per indicator; tolerate None
        message = (error_message or "").lower()

        if any(indicator in message for indicator in rate_limit_indicators):
            logger.warning(f"⚠️ Rate limit hit on provider '{provider}'")
            self._record_failure(
                FailureType.RATE_LIMIT,
                0.7,
                {"provider": provider, "error": error_message}
            )
            return True

        return False

    def detect_degradation(
        self,
        metric_name: str,
        current_value: float,
        historical_baseline: float
    ) -> bool:
        """Record a DEGRADED_QUALITY failure when a metric drifts too far.

        The relative change ``|current - baseline| / |baseline|`` is compared
        with ``degradation_threshold``. The direction of the change is not
        considered, so a large move either way is reported. A zero baseline
        is skipped because the relative change is undefined.

        Returns:
            True if a failure was recorded, False otherwise.
        """
        if historical_baseline == 0:
            return False

        # Divide by the magnitude so a negative baseline still yields a
        # non-negative rate that can be compared with the threshold
        degradation_rate = abs(current_value - historical_baseline) / abs(historical_baseline)

        if degradation_rate > self.degradation_threshold:
            logger.warning(
                f"📉 Degradation detected in '{metric_name}': "
                f"{historical_baseline} → {current_value} ({degradation_rate:.1%})"
            )
            self._record_failure(
                FailureType.DEGRADED_QUALITY,
                min(1.0, degradation_rate),
                {
                    "metric": metric_name,
                    "baseline": historical_baseline,
                    "current": current_value
                }
            )
            return True

        return False

    def detect_memory_leak(self, current_memory_mb: float) -> bool:
        """Record a MEMORY_LEAK failure when usage exceeds the threshold.

        Returns:
            True if a failure was recorded, False otherwise.
        """
        if current_memory_mb > self.memory_threshold_mb:
            logger.warning(f"💾 High memory usage: {current_memory_mb}MB")
            self._record_failure(
                FailureType.MEMORY_LEAK,
                min(1.0, current_memory_mb / 1000),
                {"memory_mb": current_memory_mb}
            )
            return True

        return False

    async def perform_recovery(
        self,
        failure: FailureEvent,
        available_providers: List[str],
        current_provider: str
    ) -> tuple[bool, Optional[str]]:
        """
        Perform recovery for a detected failure

        Every attempted recovery is recorded in ``recovery_history``.
        Failure types without a handler here (and PROVIDER_SLOW when no
        alternative provider exists) attempt nothing and record nothing.

        Returns:
            (success, new_provider) - ``new_provider`` is the provider that
            was switched to for PROVIDER_SLOW, otherwise None
        """
        failure_type = failure.failure_type

        if failure_type == FailureType.PROVIDER_SLOW:
            # Try switching to the first provider that is not the current one
            better_provider = next(
                (p for p in available_providers if p != current_provider),
                None
            )
            if better_provider:
                logger.info(f"🔄 Switching provider: {current_provider} → {better_provider}")
                success = await self._execute_recovery(
                    RecoveryAction.SWITCH_PROVIDER,
                    {"new_provider": better_provider}
                )
                self._record_recovery(failure_type, RecoveryAction.SWITCH_PROVIDER, success)
                return success, better_provider

        elif failure_type == FailureType.RATE_LIMIT:
            # Wait and retry
            wait_time = self.rate_limit_backoff_s
            logger.info("⏳ Rate limited - exponential backoff")
            await asyncio.sleep(wait_time)
            success = await self._execute_recovery(
                RecoveryAction.RETRY_EXPONENTIAL,
                {"wait_time": wait_time}
            )
            self._record_recovery(failure_type, RecoveryAction.RETRY_EXPONENTIAL, success)
            return success, None

        elif failure_type == FailureType.DEGRADED_QUALITY:
            # Clear cache and retry
            logger.info("🧹 Clearing cache to recover quality")
            success = await self._execute_recovery(
                RecoveryAction.CLEAR_CACHE,
                {}
            )
            self._record_recovery(failure_type, RecoveryAction.CLEAR_CACHE, success)
            return success, None

        elif failure_type == FailureType.MEMORY_LEAK:
            # Enter degradation mode
            logger.warning("📉 Entering degraded mode to manage memory")
            self._enter_degraded_mode("High memory usage")
            success = await self._execute_recovery(
                RecoveryAction.DEGRADE_MODE,
                {"reason": "memory"}
            )
            # Recorded like every other branch so effectiveness stats see it
            self._record_recovery(failure_type, RecoveryAction.DEGRADE_MODE, success)
            return success, None

        elif failure_type == FailureType.CIRCUIT_OPEN:
            # Use fallback while the circuit is open
            logger.info("🔌 Circuit open - using fallback mode")
            success = await self._execute_recovery(
                RecoveryAction.FALLBACK_LOCAL,
                {}
            )
            self._record_recovery(failure_type, RecoveryAction.FALLBACK_LOCAL, success)
            return success, None

        # No recovery action found
        return False, None

    def _enter_degraded_mode(self, reason: str):
        """Enter degraded mode, remembering the reason and the start time"""
        self.is_degraded = True
        self.degradation_reason = reason
        self.degradation_started = datetime.now()
        logger.warning(f"⚠️ DEGRADED MODE: {reason}")

    def _exit_degraded_mode(self):
        """Exit degraded mode, logging how long it lasted if it was active"""
        if self.degradation_started:
            duration = (datetime.now() - self.degradation_started).total_seconds()
            logger.info(f"🟢 Exiting degraded mode (lasted {duration:.0f}s)")

        self.is_degraded = False
        self.degradation_reason = None
        self.degradation_started = None

    async def _execute_recovery(
        self,
        action: RecoveryAction,
        params: Dict[str, Any]
    ) -> bool:
        """Execute a recovery action.

        If a callback is registered for ``action`` it is called with
        ``params`` as keyword arguments. Both plain functions and coroutine
        functions are supported: the result is awaited only when it is
        awaitable. Any exception raised by the callback is logged and
        reported as failure.

        Without a callback, the built-in actions report success (the caller
        or ``perform_recovery`` has already done the work); RESTART_SESSION
        and ALERT_USER have no default and report failure.
        """
        import inspect

        callback = self.recovery_callbacks.get(action)
        if callback is not None:
            try:
                result = callback(**params)
                # Awaiting a plain return value would raise TypeError after
                # the callback already ran, misreporting it as a failure
                if inspect.isawaitable(result):
                    result = await result
                return bool(result)
            except Exception as e:
                logger.error(f"❌ Recovery action failed: {e}")
                return False

        # Default implementations: nothing left to do for these actions
        default_successes = (
            RecoveryAction.SWITCH_PROVIDER,    # Caller handles
            RecoveryAction.RETRY_EXPONENTIAL,  # Already waited
            RecoveryAction.CLEAR_CACHE,        # No-op if no cache
            RecoveryAction.DEGRADE_MODE,       # Mode entered
            RecoveryAction.FALLBACK_LOCAL,     # Caller handles
        )
        return action in default_successes

    def _get_previous_solutions(self, failure_type: FailureType) -> List[RecoveryAction]:
        """Get recovery actions that previously worked for this failure type.

        Returns:
            Distinct actions ordered by how often they succeeded, most
            frequent first.
        """
        from collections import Counter

        successes = Counter(
            entry.action_taken
            for entry in self.recovery_history
            if entry.failure_type == failure_type and entry.success
        )
        return [action for action, _ in successes.most_common()]

    def _record_recovery(
        self,
        failure_type: FailureType,
        action: RecoveryAction,
        success: bool,
        details: str = ""
    ):
        """Append a recovery attempt to the history and log its outcome"""
        entry = RecoveryHistory(
            failure_type=failure_type,
            action_taken=action,
            success=success,
            details=details
        )
        self.recovery_history.append(entry)

        if success:
            logger.info(f"✅ Recovery succeeded: {action.value}")
        else:
            logger.warning(f"❌ Recovery failed: {action.value}")

    def _cleanup_old_failures(self):
        """Remove old failure records (keep last hour)"""
        cutoff = datetime.now() - timedelta(hours=1)

        for failure_type, events in self.recent_failures.items():
            self.recent_failures[failure_type] = [
                f for f in events if f.timestamp > cutoff
            ]

    def get_failure_summary(self) -> Dict[str, int]:
        """Get counts of recent failures per type.

        Prunes events older than one hour first; types with no remaining
        events are omitted from the result.
        """
        self._cleanup_old_failures()

        return {
            failure_type.value: len(failures)
            for failure_type, failures in self.recent_failures.items()
            if failures
        }

    def get_recovery_effectiveness(self) -> Dict[str, float]:
        """Get the success rate (0.0-1.0) of each recovery action attempted.

        Returns an empty dict when no recovery has been recorded.
        """
        attempts: Dict[RecoveryAction, int] = {}
        wins: Dict[RecoveryAction, int] = {}

        for entry in self.recovery_history:
            action = entry.action_taken
            attempts[action] = attempts.get(action, 0) + 1
            wins[action] = wins.get(action, 0) + (1 if entry.success else 0)

        return {
            action.value: wins[action] / total
            for action, total in attempts.items()
        }

    def get_status_message(self) -> str:
        """Human-readable status: mode, recent failures, recovery rates"""
        lines = ["🏥 **Self-Healing Status**"]

        if self.is_degraded:
            lines.append(f"⚠️ DEGRADED MODE: {self.degradation_reason}")
        else:
            lines.append("🟢 Normal operation")

        failures = self.get_failure_summary()
        if failures:
            lines.append("\n📊 **Recent Failures**")
            # Most frequent failure type first
            for failure_type, count in sorted(failures.items(), key=lambda x: x[1], reverse=True):
                lines.append(f" • {failure_type}: {count}")

        effectiveness = self.get_recovery_effectiveness()
        if effectiveness:
            lines.append("\n✅ **Recovery Effectiveness**")
            # Most effective action first
            for action, rate in sorted(effectiveness.items(), key=lambda x: x[1], reverse=True):
                lines.append(f" • {action}: {rate:.1%} success rate")

        return "\n".join(lines)
404
+
405
+
406
# Global instance, created when the module is imported
self_healing_agent = SelfHealingAgent()


if __name__ == "__main__":
    # Demo: feed a few synthetic failures into a fresh agent, then print
    # the resulting status report
    agent = SelfHealingAgent()

    agent.detect_slow_provider("cerebras", 6000, [100, 200, 300])
    agent.detect_rate_limiting("groq", "429 Too Many Requests")
    agent.detect_degradation("accuracy", 0.75, 0.95)

    status_report = agent.get_status_message()
    print(status_report)