cite-agent 1.3.9__py3-none-any.whl → 1.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cite_agent/__init__.py +13 -13
- cite_agent/__version__.py +1 -1
- cite_agent/action_first_mode.py +150 -0
- cite_agent/adaptive_providers.py +413 -0
- cite_agent/archive_api_client.py +186 -0
- cite_agent/auth.py +0 -1
- cite_agent/auto_expander.py +70 -0
- cite_agent/cache.py +379 -0
- cite_agent/circuit_breaker.py +370 -0
- cite_agent/citation_network.py +377 -0
- cite_agent/cli.py +8 -16
- cite_agent/cli_conversational.py +113 -3
- cite_agent/confidence_calibration.py +381 -0
- cite_agent/deduplication.py +325 -0
- cite_agent/enhanced_ai_agent.py +689 -371
- cite_agent/error_handler.py +228 -0
- cite_agent/execution_safety.py +329 -0
- cite_agent/full_paper_reader.py +239 -0
- cite_agent/observability.py +398 -0
- cite_agent/offline_mode.py +348 -0
- cite_agent/paper_comparator.py +368 -0
- cite_agent/paper_summarizer.py +420 -0
- cite_agent/pdf_extractor.py +350 -0
- cite_agent/proactive_boundaries.py +266 -0
- cite_agent/quality_gate.py +442 -0
- cite_agent/request_queue.py +390 -0
- cite_agent/response_enhancer.py +257 -0
- cite_agent/response_formatter.py +458 -0
- cite_agent/response_pipeline.py +295 -0
- cite_agent/response_style_enhancer.py +259 -0
- cite_agent/self_healing.py +418 -0
- cite_agent/similarity_finder.py +524 -0
- cite_agent/streaming_ui.py +13 -9
- cite_agent/thinking_blocks.py +308 -0
- cite_agent/tool_orchestrator.py +416 -0
- cite_agent/trend_analyzer.py +540 -0
- cite_agent/unpaywall_client.py +226 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/METADATA +15 -1
- cite_agent-1.4.3.dist-info/RECORD +62 -0
- cite_agent-1.3.9.dist-info/RECORD +0 -32
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/WHEEL +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/entry_points.txt +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/licenses/LICENSE +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Intelligent Request Queue with Backpressure
|
|
3
|
+
Prioritizes requests, prevents thundering herd, gracefully degrades under load
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Dict, List, Optional, Callable
|
|
11
|
+
from datetime import datetime, timedelta
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
# Logger named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RequestPriority(Enum):
    """Priority tiers for queued work, declared from most to least urgent."""

    # User-initiated work that the user is blocked on
    URGENT = 0
    # Ordinary requests
    NORMAL = 1
    # Background / analysis jobs
    BATCH = 2
    # Cleanup and archival chores
    MAINTENANCE = 3
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class QueuedRequest:
    """A unit of work parked in a priority queue until the worker runs it."""

    request_id: str
    user_id: str
    priority: RequestPriority
    submitted_at: datetime
    max_wait_time: float  # seconds the request may wait before it is considered stale
    callback: Callable
    args: tuple = field(default_factory=tuple)
    kwargs: dict = field(default_factory=dict)

    def is_expired(self) -> bool:
        """Return True once the request has waited longer than ``max_wait_time`` seconds."""
        waited = datetime.now() - self.submitted_at
        return waited.total_seconds() > self.max_wait_time
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CircuitStatus(Enum):
    """States of the queue's circuit breaker."""

    # Normal operation: requests are accepted
    CLOSED = "closed"
    # Tripped: new requests are rejected immediately
    OPEN = "open"
    # Probing: requests are let through to test whether the system recovered
    HALF_OPEN = "half_open"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class RequestQueueMetrics:
    """Point-in-time snapshot of queue health."""

    # Requests currently waiting across all priority queues
    queue_depth: int
    # Lifetime number of requests accepted into a queue
    total_queued: int
    # Lifetime number of requests whose callback completed successfully
    total_processed: int
    # Lifetime number of requests dropped for waiting too long
    total_expired: int
    # Mean time (seconds) requests spent waiting before execution
    avg_wait_time: float
    # 95th-percentile wait time (seconds) over recent requests
    p95_wait_time: float
    # Current circuit breaker state
    circuit_status: CircuitStatus
    # Requests executing right now
    active_requests: int
    # Configured global concurrency ceiling
    max_concurrent: int
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class IntelligentRequestQueue:
    """
    Priority queue with backpressure, circuit breaker integration, and metrics

    Features:
    - Priority levels (urgent > normal > batch > maintenance)
    - Per-user concurrency limits
    - Queue depth monitoring
    - Automatic circuit breaker integration
    - Request expiration (don't serve stale requests)
    - User notifications about wait time
    - Graceful degradation under load

    A single background worker pulls requests in priority order and runs each
    callback as its own asyncio task. Up to ``max_concurrent_global`` callbacks
    (and at most ``max_concurrent_per_user`` per user) therefore run concurrently.
    """

    def __init__(
        self,
        max_concurrent_global: int = 50,
        max_concurrent_per_user: int = 3,
        queue_size_limit: int = 1000,
        warning_threshold: float = 0.7,  # warn when queue at 70%
        rejection_threshold: float = 0.95,  # reject when queue at 95%
        failure_rate_threshold: float = 0.3,  # open circuit above 30% recent failures
        failure_window: int = 20,  # number of recent outcomes the failure rate covers
        min_failure_samples: int = 5,  # outcomes needed before the rate may trip the circuit
        wait_time_history: int = 1000,  # wait-time samples kept for percentile metrics
    ):
        """
        Args:
            max_concurrent_global: maximum number of callbacks running at once.
            max_concurrent_per_user: maximum callbacks running at once per user_id.
            queue_size_limit: capacity of each priority queue, and the denominator
                used for queue-usage ratios (<= 0 means unbounded).
            warning_threshold: usage ratio above which submitters get a warning.
            rejection_threshold: usage ratio above which submissions are rejected.
            failure_rate_threshold: recent failure ratio above which the circuit opens.
            failure_window: how many of the most recent outcomes form the failure ratio.
            min_failure_samples: minimum outcomes in the window before a CLOSED
                circuit may open.
            wait_time_history: how many wait-time samples ``wait_times`` retains.
        """
        self.max_concurrent_global = max_concurrent_global
        self.max_concurrent_per_user = max_concurrent_per_user
        self.queue_size_limit = queue_size_limit
        self.warning_threshold = warning_threshold
        self.rejection_threshold = rejection_threshold
        self.failure_rate_threshold = failure_rate_threshold
        self.failure_window = max(1, failure_window)
        self.min_failure_samples = max(1, min_failure_samples)
        self.wait_time_history = max(1, wait_time_history)

        # Queues by priority
        self.queues: Dict[RequestPriority, asyncio.Queue] = {
            priority: asyncio.Queue(maxsize=queue_size_limit)
            for priority in RequestPriority
        }

        # Active requests tracking
        self.active_requests: Dict[str, datetime] = {}  # request_id -> start_time
        self.user_active: Dict[str, int] = {}  # user_id -> running count (idle users removed)

        # Metrics
        self.total_processed = 0  # callbacks that completed successfully
        self.total_failed = 0  # callbacks that raised
        self.total_queued = 0
        self.total_expired = 0
        self.wait_times: List[float] = []  # bounded window of recent waits, for p95
        # Running totals so the average still covers every request even though
        # wait_times is bounded.
        self._wait_time_sum = 0.0
        self._wait_time_count = 0
        # Outcomes of the most recent requests (True = success), oldest first.
        self._recent_outcomes: List[bool] = []
        # Disambiguates generated request IDs created within the same clock tick.
        self._request_seq = 0

        # Circuit breaker state
        self.circuit_status = CircuitStatus.CLOSED
        self.circuit_open_at: Optional[datetime] = None
        self.circuit_recovery_timeout = 30  # seconds

        # Background worker and the callback tasks it has spawned. Strong
        # references are kept so running tasks cannot be garbage-collected.
        self.worker_task: Optional[asyncio.Task] = None
        self.is_running = False
        self._inflight_tasks: set = set()

    async def start(self):
        """Start the queue worker (no-op if already running)."""
        if self.is_running:
            return

        self.is_running = True
        self.worker_task = asyncio.create_task(self._process_queue())
        logger.info("🚀 Request queue started")

    async def stop(self):
        """Stop the worker and wait for in-flight callbacks to finish.

        Requests still waiting in the queues are left in place; they are
        neither executed nor cancelled.
        """
        self.is_running = False
        if self.worker_task:
            await self.worker_task
            self.worker_task = None
        if self._inflight_tasks:
            # Failures were already logged by _execute_request; don't re-raise them here.
            await asyncio.gather(*list(self._inflight_tasks), return_exceptions=True)
        logger.info("⛔ Request queue stopped")

    async def submit(
        self,
        user_id: str,
        callback: Callable,
        priority: RequestPriority = RequestPriority.NORMAL,
        max_wait_time: float = 30.0,
        request_id: Optional[str] = None,
        *args,
        **kwargs
    ) -> tuple[bool, Optional[str]]:
        """
        Submit a request to the queue

        Extra positional/keyword arguments are forwarded to ``callback``.
        Because ``*args`` follows several defaulted parameters, positional
        callback arguments can only be supplied after explicitly passing
        ``priority``, ``max_wait_time`` and ``request_id``. Passing callback
        arguments as keywords is usually simpler.

        Returns:
            (success, message): on success the message reports the queue
            position (plus a warning when the system is busy); on rejection it
            explains why.
        """
        if request_id is None:
            self._request_seq += 1
            # The sequence suffix keeps IDs unique even when time.time() ties.
            request_id = f"{user_id}_{time.time()}_{self._request_seq}"

        # Check queue capacity
        queue_usage = self._get_queue_usage()

        if queue_usage > self.rejection_threshold:
            return False, f"System overloaded (queue at {queue_usage*100:.0f}%). Please try again in 30 seconds."

        if queue_usage > self.warning_threshold:
            warning = f"⚠️ System busy. Your request may take up to {max_wait_time:.0f}s."
        else:
            warning = None

        # Check circuit breaker
        if self.circuit_status == CircuitStatus.OPEN:
            if self._should_attempt_recovery():
                self.circuit_status = CircuitStatus.HALF_OPEN
                logger.info("🔄 Circuit breaker: attempting recovery")
            else:
                return False, "System is temporarily unavailable. Retrying in 30s..."

        request = QueuedRequest(
            request_id=request_id,
            user_id=user_id,
            priority=priority,
            submitted_at=datetime.now(),
            max_wait_time=max_wait_time,
            callback=callback,
            args=args,
            kwargs=kwargs
        )

        # Add to appropriate priority queue
        try:
            self.queues[priority].put_nowait(request)
        except asyncio.QueueFull:
            return False, "Queue is full. Please try again soon."

        self.total_queued += 1

        message = f"✓ Queued (position #{self._get_queue_position(priority)})"
        if warning:
            message += f"\n{warning}"

        return True, message

    async def _process_queue(self):
        """Main worker: continuously dispatch queued requests as concurrent tasks."""
        while self.is_running:
            try:
                # Respect the global concurrency ceiling
                if len(self._inflight_tasks) >= self.max_concurrent_global:
                    await asyncio.sleep(0.1)
                    continue

                # Highest-priority request whose user is under the per-user limit
                request = await self._get_next_request()
                if request is None:
                    await asyncio.sleep(0.1)
                    continue

                self._dispatch(request)
                # Yield so the new task and any submitters get a chance to run.
                await asyncio.sleep(0)

            except Exception as e:
                logger.error(f"❌ Queue worker error: {e}", exc_info=True)
                await asyncio.sleep(1)

    async def _get_next_request(self) -> Optional[QueuedRequest]:
        """Pop the highest-priority, non-expired request whose user has capacity.

        Each priority queue is rotated exactly once. Expired requests are
        dropped and counted. Requests from users at their concurrency limit
        are put back, and a full rotation returns them in their original
        order. Returns None when nothing is dispatchable.
        """
        for priority in RequestPriority:
            queue = self.queues[priority]
            chosen: Optional[QueuedRequest] = None
            for _ in range(queue.qsize()):
                try:
                    candidate = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

                if candidate.is_expired():
                    self.total_expired += 1
                    logger.warning(f"⏰ Request {candidate.request_id} expired (waited too long)")
                    continue

                has_capacity = (
                    self.user_active.get(candidate.user_id, 0) < self.max_concurrent_per_user
                )
                if chosen is None and has_capacity:
                    chosen = candidate
                else:
                    # Cannot overflow: we just removed at least this many items.
                    queue.put_nowait(candidate)

            if chosen is not None:
                return chosen

        return None

    def _dispatch(self, request: QueuedRequest) -> None:
        """Mark ``request`` active synchronously and run it as a background task.

        Tracking happens before the task starts, so the next worker iteration
        already sees this request in the per-user count.
        """
        start_time = self._start_tracking(request)
        try:
            task = asyncio.create_task(self._execute_request(request, start_time))
        except BaseException:
            self._stop_tracking(request, start_time)
            raise
        self._inflight_tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Forget a finished callback task and mark its exception as retrieved."""
        self._inflight_tasks.discard(task)
        if not task.cancelled():
            # Already logged in _execute_request; retrieving it prevents asyncio's
            # "Task exception was never retrieved" warning.
            task.exception()

    def _start_tracking(self, request: QueuedRequest) -> datetime:
        """Register ``request`` as active, record its wait time, and return its start time."""
        start_time = datetime.now()
        self.active_requests[request.request_id] = start_time
        self.user_active[request.user_id] = self.user_active.get(request.user_id, 0) + 1

        wait_time = (start_time - request.submitted_at).total_seconds()
        self.wait_times.append(wait_time)
        overflow = len(self.wait_times) - self.wait_time_history
        if overflow > 0:
            del self.wait_times[:overflow]
        self._wait_time_sum += wait_time
        self._wait_time_count += 1

        logger.debug(f"▶️ Executing {request.request_id} (waited {wait_time:.1f}s)")
        return start_time

    def _stop_tracking(self, request: QueuedRequest, start_time: datetime) -> None:
        """Undo ``_start_tracking`` for ``request`` and log its duration."""
        self.active_requests.pop(request.request_id, None)
        remaining = self.user_active.get(request.user_id, 0) - 1
        if remaining > 0:
            self.user_active[request.user_id] = remaining
        else:
            # Drop idle users so the dict does not grow with every user ever seen.
            self.user_active.pop(request.user_id, None)

        elapsed = (datetime.now() - start_time).total_seconds()
        logger.debug(f"✓ Request {request.request_id} completed in {elapsed:.2f}s")

    async def _execute_request(self, request: QueuedRequest, start_time: Optional[datetime] = None):
        """Execute a request and track metrics

        Args:
            request: the dequeued request.
            start_time: set when the caller has already registered the request
                as active (see ``_dispatch``); when None, tracking starts here.

        Returns:
            The callback's result.

        Raises:
            Whatever the callback raised, after the failure has been recorded.
        """
        if start_time is None:
            start_time = self._start_tracking(request)

        try:
            result = await request.callback(*request.args, **request.kwargs)
        except Exception as e:
            logger.error(f"❌ Request {request.request_id} failed: {e}")
            self._on_request_failure()
            raise
        else:
            self.total_processed += 1
            self._on_request_success()
            return result
        finally:
            self._stop_tracking(request, start_time)

    def _record_outcome(self, success: bool) -> None:
        """Append an outcome to the sliding window used for the failure rate."""
        self._recent_outcomes.append(success)
        overflow = len(self._recent_outcomes) - self.failure_window
        if overflow > 0:
            del self._recent_outcomes[:overflow]

    def _open_circuit(self, reason: str) -> None:
        """Trip the circuit breaker and remember when it opened."""
        logger.error(f"🔴 Circuit breaker: OPEN ({reason})")
        self.circuit_status = CircuitStatus.OPEN
        self.circuit_open_at = datetime.now()

    def _on_request_success(self) -> None:
        """Called when a request succeeds - closes a half-open circuit."""
        self._record_outcome(True)
        if self.circuit_status == CircuitStatus.HALF_OPEN:
            self.circuit_status = CircuitStatus.CLOSED
            self.circuit_open_at = None
            # Start the new closed period with a clean failure history.
            self._recent_outcomes.clear()
            logger.info("🟢 Circuit breaker: recovered")

    def _on_request_failure(self):
        """Called when a request fails - updates circuit breaker

        The circuit opens when more than ``failure_rate_threshold`` of the last
        ``failure_window`` outcomes failed (with at least ``min_failure_samples``
        outcomes observed). A failure while HALF_OPEN reopens it immediately.
        """
        self.total_failed += 1
        self._record_outcome(False)

        if self.circuit_status == CircuitStatus.OPEN:
            return
        if self.circuit_status == CircuitStatus.HALF_OPEN:
            self._open_circuit("recovery attempt failed")
            return

        samples = len(self._recent_outcomes)
        if samples < self.min_failure_samples:
            return
        failure_rate = self._recent_outcomes.count(False) / samples
        if failure_rate > self.failure_rate_threshold:
            self._open_circuit(f"failure rate {failure_rate:.1%}")

    def _should_attempt_recovery(self) -> bool:
        """Check if circuit breaker should attempt recovery"""
        if not self.circuit_open_at:
            return True

        elapsed = (datetime.now() - self.circuit_open_at).total_seconds()
        return elapsed > self.circuit_recovery_timeout

    def _get_queue_usage(self) -> float:
        """Get current queue usage as a ratio (0.0 to 1.0)."""
        if self.queue_size_limit <= 0:
            # asyncio.Queue(maxsize<=0) is unbounded: there is no capacity to exhaust.
            return 0.0
        return min(1.0, self._get_queue_depth() / self.queue_size_limit)

    def _get_queue_depth(self) -> int:
        """Get total requests in queue"""
        return sum(q.qsize() for q in self.queues.values())

    def _get_queue_position(self, priority: RequestPriority) -> int:
        """Number of waiting requests at ``priority`` or more urgent (served no later)."""
        return sum(
            self.queues[p].qsize() for p in RequestPriority if p.value <= priority.value
        )

    def get_metrics(self) -> RequestQueueMetrics:
        """Get current queue metrics"""
        queue_depth = self._get_queue_depth()
        wait_times_sorted = sorted(self.wait_times[-100:])  # Last 100 requests

        p95_wait = wait_times_sorted[int(len(wait_times_sorted) * 0.95)] if wait_times_sorted else 0
        avg_wait = self._wait_time_sum / self._wait_time_count if self._wait_time_count else 0

        return RequestQueueMetrics(
            queue_depth=queue_depth,
            total_queued=self.total_queued,
            total_processed=self.total_processed,
            total_expired=self.total_expired,
            avg_wait_time=avg_wait,
            p95_wait_time=p95_wait,
            circuit_status=self.circuit_status,
            active_requests=len(self.active_requests),
            max_concurrent=self.max_concurrent_global
        )

    def get_status_message(self) -> str:
        """Human-readable queue status"""
        metrics = self.get_metrics()
        usage = self._get_queue_usage()

        lines = [
            "📊 **Request Queue Status**",
            f"• Queue depth: {metrics.queue_depth}/{self.queue_size_limit} ({usage*100:.0f}%)",
            f"• Active requests: {metrics.active_requests}/{metrics.max_concurrent}",
            f"• Processed: {metrics.total_processed} | Queued: {metrics.total_queued} | Expired: {metrics.total_expired}",
            f"• Avg wait: {metrics.avg_wait_time:.1f}s | P95 wait: {metrics.p95_wait_time:.1f}s",
            f"• Circuit breaker: {metrics.circuit_status.value.upper()}",
        ]

        return "\n".join(lines)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
# Example usage
|
|
356
|
+
async def example():
    """Example of using the queue: submit five queries for one user and report status."""
    queue = IntelligentRequestQueue(
        max_concurrent_global=10,
        max_concurrent_per_user=2
    )

    await queue.start()

    # Simulate a callback
    async def process_query(query: str) -> str:
        await asyncio.sleep(1)  # Simulate work
        return f"Result for: {query}"

    # Submit requests
    for i in range(5):
        # Callback arguments are forwarded through submit()'s **kwargs. Writing
        # ``args=(...)`` here would reach process_query as an unexpected ``args``
        # keyword and make every request fail with TypeError.
        success, msg = await queue.submit(
            user_id="user1",
            callback=process_query,
            priority=RequestPriority.NORMAL,
            query=f"query_{i}",
        )
        print(f"Request {i}: {msg}")

    # Wait a bit
    await asyncio.sleep(10)

    # Check status
    print(queue.get_status_message())

    await queue.stop()
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
if __name__ == "__main__":
    # Run the demo on a fresh event loop when executed as a script.
    asyncio.run(example())
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Response Enhancer - Polish responses to 0.80+ quality
|
|
3
|
+
Takes good responses and makes them great
|
|
4
|
+
|
|
5
|
+
Target: Every response should score 0.80+ on quality metrics
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from typing import Dict, Any, List
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
# Logger named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ResponseEnhancer:
    """
    Enhances responses to maximize quality scores

    Focus areas:
    1. Completeness - Address all key terms from query
    2. Structure - Add bullets, headers, emphasis
    3. Clarity - Make more direct and specific
    4. Scannability - Break up walls of text

    Every step is a heuristic text rewrite. Structural rewrites (bulleting,
    paragraph re-flow, comma splitting) are skipped for responses that contain
    fenced code blocks (```), whose line layout must be preserved.
    """

    # Query words with no topical meaning; never treated as key terms.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'what', 'how', 'why',
        'when', 'where', 'who', 'which', 'do', 'does', 'did', 'can', 'could',
        'would', 'should', 'me', 'my', 'you', 'your', 'in', 'on', 'at', 'to',
        'for', 'of', 'with', 'from',
    })

    # A line "looks like a list item" if it contains an ordinal/transition word
    # as a whole word, or a "1." / "2." enumerator that is not part of a number
    # such as "1.5" or "v2.1.".
    _LIST_INDICATOR_RE = re.compile(
        r"\b(?:first(?:ly)?|second(?:ly)?|also|additionally)\b|(?<![\d.])[12]\.(?!\d)",
        re.IGNORECASE,
    )

    # (pattern, replacement) pairs for wordy hedges. Matching is case-insensitive
    # and whitespace-tolerant because the pronoun "I" is normally capitalized.
    _HEDGE_PATTERNS = [
        (
            re.compile(r"\b" + r"\s+".join(map(re.escape, wordy.split())) + r"\b", re.IGNORECASE),
            concise,
        )
        for wordy, concise in (
            ('i think maybe', 'probably'),
            ('i believe that possibly', 'likely'),
            ('it seems like perhaps', 'it appears'),
            ('i might suggest', 'I suggest'),
            ('it could potentially be', 'it may be'),
        )
    ]

    # Conversational openers removed from the very start of a response.
    _FILLER_STARTERS = ('Well, ', 'So, ', 'Basically, ', 'Actually, ', 'You know, ')

    @classmethod
    def enhance(cls, response: str, query: str, context: Dict[str, Any]) -> str:
        """
        Enhance a response to maximize quality

        Args:
            response: Original response
            query: User's query (None is treated as empty)
            context: Context including tools, data, etc. (None is treated as {})

        Returns:
            Enhanced response (unchanged if shorter than 10 characters)
        """
        if not response or len(response) < 10:
            return response

        query = query or ""
        context = context or {}

        enhanced = cls._add_structure(response, query)
        enhanced = cls._improve_completeness(enhanced, query, context)
        enhanced = cls._improve_clarity(enhanced)
        enhanced = cls._improve_scannability(enhanced)
        enhanced = cls._add_specificity(enhanced, context)

        return enhanced

    @staticmethod
    def _has_code_block(response: str) -> bool:
        """True when the response contains a fenced code block marker."""
        return '```' in response

    @classmethod
    def _add_structure(cls, response: str, query: str) -> str:
        """Turn list-like prose lines into bullets when the response is unstructured.

        Only applies to responses of at least 100 words with no existing
        bullets and no fenced code, where at least two lines contain a list
        indicator. ``query`` is accepted for interface compatibility and unused.
        """
        has_bullets = '•' in response or '- ' in response
        has_emphasis = '**' in response

        if has_bullets and has_emphasis:
            return response  # Already well-structured

        # Short responses (< 100 words) don't benefit from extra structure
        if len(response.split()) < 100:
            return response

        if has_bullets or cls._has_code_block(response):
            return response

        lines = response.split('\n')
        if len(lines) <= 1:
            return response

        list_like = sum(1 for line in lines if cls._LIST_INDICATOR_RE.search(line))
        if list_like < 2:
            return response

        bulleted = []
        for line in lines:
            item = line.strip()
            # Leave blank lines, headings ending in ':' and existing bullets alone.
            if item and not item.endswith(':') and not item.startswith('•'):
                bulleted.append(f"• {item}")
            else:
                bulleted.append(item)

        return '\n'.join(bulleted)

    @classmethod
    def _improve_completeness(cls, response: str, query: str, context: Dict[str, Any]) -> str:
        """Make response more complete by addressing key query terms.

        Key terms are query words longer than 3 characters (after stripping
        punctuation) that are not stop words. When more than half are missing
        from the response, the query is about files/directories and the shell
        tool was used, the first "We're in" is prefixed with a note that the
        directory was checked.
        """
        context = context or {}
        query = query or ""

        query_terms = []
        for word in query.split():
            # Strip punctuation first so "which," and "file?" are classified correctly.
            term = word.lower().strip('?.,!:;')
            if len(term) > 3 and term not in cls._STOP_WORDS:
                query_terms.append(term)

        if not query_terms:
            return response

        response_lower = response.lower()
        missing_terms = [term for term in query_terms if term not in response_lower]

        # Only intervene when more than 50% of key terms are missing
        if len(missing_terms) <= len(query_terms) * 0.5:
            return response

        if not (context.get('api_results') or context.get('tools_used')):
            return response

        tools_used = context.get('tools_used') or []  # tolerate an explicit None
        if 'shell_execution' not in tools_used:
            return response

        query_lower = query.lower()
        if not any(term in query_lower for term in ('file', 'directory', 'folder')):
            return response

        if 'file' in missing_terms or 'directory' in missing_terms:
            # Make it clear we checked files/directories (once, not per occurrence)
            response = response.replace(
                "We're in",
                "I checked the current directory. We're in",
                1,
            )

        return response

    @staticmethod
    def _hedge_replacement(match: "re.Match", concise: str) -> str:
        """Return ``concise``, capitalized when the hedge starts a sentence or line."""
        preceding = match.string[:match.start()].rstrip(' \t')
        if not preceding or preceding[-1] in '.!?\n':
            return concise[0].upper() + concise[1:]
        return concise

    @classmethod
    def _improve_clarity(cls, response: str) -> str:
        """Make response more clear and direct.

        Replaces wordy hedges with concise equivalents and strips leading
        filler phrases, capitalizing the new first letter only when a filler
        was actually removed.
        """
        enhanced = response
        for pattern, concise in cls._HEDGE_PATTERNS:
            enhanced = pattern.sub(
                lambda m, c=concise: cls._hedge_replacement(m, c), enhanced
            )

        stripped_filler = False
        for filler in cls._FILLER_STARTERS:
            if enhanced.startswith(filler):
                enhanced = enhanced[len(filler):]
                stripped_filler = True

        # Only re-capitalize text we actually trimmed; leave e.g. "npm ..." alone.
        if stripped_filler and enhanced:
            enhanced = enhanced[0].upper() + enhanced[1:]

        return enhanced

    @classmethod
    def _improve_scannability(cls, response: str) -> str:
        """Make response more scannable.

        A single-line wall of text over 300 characters with at least 4
        sentences is re-flowed into 2-sentence paragraphs. Otherwise, lines
        over 200 characters with at least 3 ", "-separated parts are turned
        into a heading plus indented bullets. Responses with fenced code are
        returned unchanged.
        """
        if cls._has_code_block(response):
            return response

        # Re-flow only true walls of text: text that already has line breaks
        # (bullets, lists) would lose its structure if joined with spaces.
        if '\n' not in response and len(response) > 300:
            sentences = re.split(r'(?<=[.!?])\s+', response)
            if len(sentences) >= 4:
                paragraphs = [
                    ' '.join(sentences[i:i + 2]) for i in range(0, len(sentences), 2)
                ]
                return '\n\n'.join(paragraphs)

        # Break up super long lines
        enhanced_lines = []
        for line in response.split('\n'):
            parts = line.split(', ') if len(line) > 200 and ',' in line else []
            if len(parts) >= 3:
                enhanced_lines.append(parts[0] + ':')
                enhanced_lines.extend(f"  • {part.strip()}" for part in parts[1:])
            else:
                enhanced_lines.append(line)

        return '\n'.join(enhanced_lines)

    @classmethod
    def _add_specificity(cls, response: str, context: Dict[str, Any]) -> str:
        """Replace vague file counts with the size of the first list in api_results.

        Assumes the first non-empty list/tuple value in ``api_results`` is
        the relevant file list — NOTE(review): confirm against callers.
        """
        vague_phrases = ('some files', 'a few', 'several', 'multiple', 'various')

        response_lower = response.lower()
        if not any(phrase in response_lower for phrase in vague_phrases):
            return response

        api_results = (context or {}).get('api_results', {})
        if 'files' not in response_lower or not isinstance(api_results, dict):
            return response

        for value in api_results.values():
            if isinstance(value, (list, tuple)) and value:
                count = len(value)
                for vague in ('some files', 'a few files', 'several files'):
                    response = response.replace(vague, f'{count} files')
                break

        return response
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def enhance_response(response: str, query: str, context: Dict[str, Any] = None) -> str:
    """Shortcut for ``ResponseEnhancer.enhance``; a missing/empty context becomes ``{}``."""
    effective_context = {} if not context else context
    return ResponseEnhancer.enhance(response, query, effective_context)
|