cite-agent 1.3.9__py3-none-any.whl → 1.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cite_agent/__init__.py +13 -13
- cite_agent/__version__.py +1 -1
- cite_agent/action_first_mode.py +150 -0
- cite_agent/adaptive_providers.py +413 -0
- cite_agent/archive_api_client.py +186 -0
- cite_agent/auth.py +0 -1
- cite_agent/auto_expander.py +70 -0
- cite_agent/cache.py +379 -0
- cite_agent/circuit_breaker.py +370 -0
- cite_agent/citation_network.py +377 -0
- cite_agent/cli.py +8 -16
- cite_agent/cli_conversational.py +113 -3
- cite_agent/confidence_calibration.py +381 -0
- cite_agent/deduplication.py +325 -0
- cite_agent/enhanced_ai_agent.py +689 -371
- cite_agent/error_handler.py +228 -0
- cite_agent/execution_safety.py +329 -0
- cite_agent/full_paper_reader.py +239 -0
- cite_agent/observability.py +398 -0
- cite_agent/offline_mode.py +348 -0
- cite_agent/paper_comparator.py +368 -0
- cite_agent/paper_summarizer.py +420 -0
- cite_agent/pdf_extractor.py +350 -0
- cite_agent/proactive_boundaries.py +266 -0
- cite_agent/quality_gate.py +442 -0
- cite_agent/request_queue.py +390 -0
- cite_agent/response_enhancer.py +257 -0
- cite_agent/response_formatter.py +458 -0
- cite_agent/response_pipeline.py +295 -0
- cite_agent/response_style_enhancer.py +259 -0
- cite_agent/self_healing.py +418 -0
- cite_agent/similarity_finder.py +524 -0
- cite_agent/streaming_ui.py +13 -9
- cite_agent/thinking_blocks.py +308 -0
- cite_agent/tool_orchestrator.py +416 -0
- cite_agent/trend_analyzer.py +540 -0
- cite_agent/unpaywall_client.py +226 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/METADATA +15 -1
- cite_agent-1.4.3.dist-info/RECORD +62 -0
- cite_agent-1.3.9.dist-info/RECORD +0 -32
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/WHEEL +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/entry_points.txt +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/licenses/LICENSE +0 -0
- {cite_agent-1.3.9.dist-info → cite_agent-1.4.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Graceful Error Handling - Never Expose Technical Details to Users
|
|
3
|
+
Converts technical errors to friendly, actionable messages
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Optional, Dict, Any
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GracefulErrorHandler:
    """
    Converts all technical errors to user-friendly messages

    PRINCIPLE: Users should never see:
    - Stack traces
    - API error codes
    - Certificate errors
    - Connection details
    - Internal variable names
    - Technical jargon

    PRINCIPLE: Users should always see:
    - What went wrong in simple terms
    - What they can do about it
    - Alternative options if available
    """

    # User-friendly error messages mapped from technical errors
    ERROR_MESSAGES = {
        # Network / Connection
        'ConnectionError': "I'm having trouble connecting right now. Please try again in a moment.",
        'Timeout': "That's taking longer than expected. Let me try a simpler approach.",
        'TimeoutError': "That's taking longer than expected. Let me try a simpler approach.",
        'ConnectTimeout': "I couldn't connect. Please check your network and try again.",

        # API Errors
        'HTTPError': "I encountered an issue accessing that service. Let me try again.",
        'APIError': "Something went wrong on my end. Let me try another way.",
        'RateLimitError': "I've hit my usage limit. Please try again in a few minutes.",
        'QuotaExceeded': "I've reached my daily limit. Please try again tomorrow.",

        # Authentication / Authorization
        'AuthenticationError': "I'm having trouble with authentication. Please check your setup.",
        'PermissionError': "I don't have permission to access that. Please check the permissions.",
        'UnauthorizedError': "I need authorization to access that resource.",

        # Data / Parsing
        'JSONDecodeError': "I received data in an unexpected format. Let me try again.",
        'ParseError': "I couldn't understand the response. Let me try another approach.",
        'ValueError': "I received unexpected data. Let me try again.",
        'KeyError': "I couldn't find the expected information. Let me try differently.",

        # TLS / SSL / Certificate
        'SSLError': "I'm having trouble with the secure connection. Please try again.",
        'CertificateError': "I'm having trouble with the secure connection. Please try again.",
        'TLS_error': "I'm having trouble with the secure connection. Please try again.",

        # File System
        'FileNotFoundError': "I couldn't find that file. Please check the path.",
        'IsADirectoryError': "That's a directory, not a file. Please specify a file path.",
        'NotADirectoryError': "That's a file, not a directory. Please specify a directory path.",
        'PermissionError_file': "I don't have permission to access that file.",

        # General
        'Exception': "Something unexpected happened. Let me try again.",
        'RuntimeError': "I encountered an unexpected issue. Let me try another approach.",
    }

    @classmethod
    def handle_error(
        cls,
        error: Exception,
        context: str = "",
        fallback_action: Optional[str] = None
    ) -> str:
        """
        Convert any error to a user-friendly message

        Args:
            error: The exception that occurred
            context: What the agent was trying to do (e.g., "search papers", "read file")
            fallback_action: What the user can do instead (e.g., "try a different search")

        Returns:
            User-friendly error message (never technical details)
        """
        # Log technical details for debugging (but don't show to user!)
        logging.getLogger(__name__).error(
            f"Error in {context}: {type(error).__name__}: {str(error)}", exc_info=True
        )

        # Look up a user-friendly message by the exception's type name first
        error_type = type(error).__name__
        user_message = cls.ERROR_MESSAGES.get(error_type)

        if not user_message:
            # Fall back to pattern matching on the error text
            error_str = str(error).lower()
            if 'certificate' in error_str or 'tls' in error_str or 'ssl' in error_str:
                user_message = cls.ERROR_MESSAGES['CertificateError']
            elif 'timeout' in error_str:
                user_message = cls.ERROR_MESSAGES['Timeout']
            elif 'connection' in error_str or 'connect' in error_str:
                user_message = cls.ERROR_MESSAGES['ConnectionError']
            elif 'rate limit' in error_str or 'quota' in error_str:
                user_message = cls.ERROR_MESSAGES['RateLimitError']
            elif 'auth' in error_str or 'unauthorized' in error_str:
                user_message = cls.ERROR_MESSAGES['AuthenticationError']
            elif 'not found' in error_str:
                user_message = cls.ERROR_MESSAGES['FileNotFoundError']
            else:
                # Generic fallback
                user_message = cls.ERROR_MESSAGES['Exception']

        # Add context if provided
        if context:
            # BUG FIX: the previous code lowercased the ENTIRE message, which
            # mangled the pronoun "I" (e.g. "While trying to X, i'm having
            # trouble..."). Lowercase only the first character, and leave the
            # message untouched when it starts with the pronoun "I".
            if user_message.startswith(("I ", "I'")):
                tail = user_message
            else:
                tail = user_message[0].lower() + user_message[1:]
            full_message = f"While trying to {context}, {tail}"
        else:
            full_message = user_message

        # Add fallback action if provided
        if fallback_action:
            full_message += f" {fallback_action}"

        return full_message

    @classmethod
    def wrap_response_with_error_handling(cls, response: str) -> str:
        """
        Scan response for any leaked technical errors and clean them

        This is a safety net in case errors slip through.
        Terms with a friendly replacement are substituted in place
        (case-insensitively); terms with an empty replacement cause the
        whole offending line to be dropped.
        """
        # Hoisted out of the loop — no need to re-import per pattern
        import re

        # Technical terms that should NEVER appear in user responses
        forbidden_patterns = [
            ('TLS_error', 'secure connection issue'),
            ('CERTIFICATE_VERIFY_FAILED', 'secure connection issue'),
            ('upstream connect error', 'connection issue'),
            ('stack trace', ''),
            ('Traceback (most recent call last)', ''),
            ('Exception:', ''),
            ('ERROR:', ''),
            ('⚠️ I couldn\'t finish the reasoning step', 'I encountered an issue'),
            ('language model call failed', 'I had trouble processing that'),
            ('API call failed', 'I had trouble accessing that service'),
        ]

        cleaned_response = response
        had_technical_errors = False

        for technical_term, friendly_replacement in forbidden_patterns:
            if technical_term.lower() in cleaned_response.lower():
                # If we find technical errors, replace with friendly version
                logging.getLogger(__name__).warning(
                    f"Found leaked technical error in response: {technical_term}"
                )
                had_technical_errors = True

                if friendly_replacement:
                    # Replace with friendly term
                    cleaned_response = re.sub(
                        re.escape(technical_term),
                        friendly_replacement,
                        cleaned_response,
                        flags=re.IGNORECASE
                    )
                else:
                    # Remove the offending line entirely
                    cleaned_response = '\n'.join(
                        line for line in cleaned_response.split('\n')
                        if technical_term.lower() not in line.lower()
                    )

        # If the response became empty or too short AFTER cleaning errors, provide generic friendly message
        # Don't flag legitimately short responses (greetings, acknowledgments, etc.)
        if had_technical_errors and len(cleaned_response.strip()) < 20:
            cleaned_response = "I encountered an issue while processing that. Could you try rephrasing your question?"

        return cleaned_response

    @classmethod
    def create_fallback_response(cls, original_query: str, error: Exception) -> str:
        """
        Create a complete fallback response when main processing fails

        Returns a helpful response instead of exposing the error
        """
        # Get friendly error message
        error_msg = cls.handle_error(error, "process your request")

        # Try to be helpful based on query type
        query_lower = original_query.lower()

        suggestions = []

        if any(word in query_lower for word in ['search', 'find', 'papers', 'research']):
            suggestions.append("• Try a more specific search term")
            suggestions.append("• Check if the topic exists in our database")

        if any(word in query_lower for word in ['revenue', 'stock', 'financial', 'company']):
            suggestions.append("• Try searching for a specific company by name")
            suggestions.append("• Check if the company is publicly traded")

        if any(word in query_lower for word in ['file', 'directory', 'folder', 'read']):
            suggestions.append("• Check if the file path is correct")
            suggestions.append("• Try using an absolute path")

        response_parts = [error_msg]

        if suggestions:
            response_parts.append("\nYou could try:")
            response_parts.extend(suggestions)

        return '\n'.join(response_parts)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Convenient function for quick error handling
|
|
222
|
+
def handle_error_gracefully(error: Exception, context: str = "", fallback_action: Optional[str] = None) -> str:
    """Module-level shortcut that delegates to ``GracefulErrorHandler.handle_error``.

    Args:
        error: The exception to translate.
        context: What was being attempted when the error occurred.
        fallback_action: Optional suggestion appended to the message.

    Returns:
        A user-friendly error message with no technical details.
    """
    return GracefulErrorHandler.handle_error(error, context, fallback_action)
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command Execution Safety & Enforcement
|
|
3
|
+
Ensures backend executes EXACTLY what agent planned - validates pre/post execution
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import hashlib
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Dict, List, Optional, Any
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CommandClassification(Enum):
    """Safety tiers used to gate command execution."""

    # Non-destructive operations: file reads, harmless queries.
    SAFE = "safe"
    # Operations that create or modify files.
    WRITE = "write"
    # Destructive operations (e.g. rm -rf, disk formatting).
    DANGEROUS = "dangerous"
    # Commands that must never be executed at all.
    BLOCKED = "blocked"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CommandAuditLevel(Enum):
    """How strictly command execution is audited."""

    # Everything is allowed; mismatches are only logged.
    PERMISSIVE = "permissive"
    # Plans must be approved before execution.
    STRICT = "strict"
    # Plan/execution mismatches are rejected outright.
    ENFORCED = "enforced"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class CommandPlan:
    """Declares the command an agent intends to run, and why."""

    command: str
    classification: CommandClassification
    reason: str  # Human-readable justification for running this command
    expected_output_pattern: Optional[str] = None
    max_execution_time_s: float = 60.0

    def get_hash(self) -> str:
        """Return a short, stable fingerprint of the command text."""
        digest = hashlib.sha256(self.command.encode())
        return digest.hexdigest()[:16]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class CommandExecution:
    """Audit record for a single executed command."""

    command: str
    planned_hash: str            # fingerprint of what the agent planned
    executed_hash: str           # fingerprint of what actually ran
    classification: CommandClassification
    status: str                  # success, failure, timeout, mismatch
    exit_code: int = 0
    output: str = ""
    error: str = ""
    execution_time_s: float = 0.0
    timestamp: datetime = field(default_factory=datetime.now)
    user_id: Optional[str] = None

    def was_modified(self) -> bool:
        """Return True when the executed command differs from the plan."""
        return self.planned_hash != self.executed_hash

    def to_audit_log(self) -> str:
        """Render this record as a single audit-log line."""
        flag = "⚠️ MODIFIED" if self.was_modified() else "✓ AS-PLANNED"
        return (
            f"{self.timestamp.isoformat()} | {self.classification.value.upper()} | "
            f"{flag} | exit={self.exit_code} | {self.command[:60]}..."
        )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class CommandExecutionValidator:
    """
    Pre and post-execution validation to ensure safety

    Prevents:
    - Agent plans "cat file.txt", backend executes "rm -rf /"
    - Command injection attacks
    - Unexpected output/errors
    - Timeout vulnerabilities
    """

    def __init__(self, audit_level: CommandAuditLevel = CommandAuditLevel.STRICT):
        self.audit_level = audit_level
        # Chronological record of every validated execution.
        self.audit_log: List[CommandExecution] = []
        # Substrings that mark a command as destructive.
        self.dangerous_patterns = [
            "rm -rf",
            "mkfs",
            "format",
            ": () { :",  # Bash fork bomb
            "dd if=/dev/zero of=/",
            "chmod -R 777 /",
            "reboot",
            "shutdown -h",
        ]
        # Commands that are never allowed, regardless of audit level.
        self.blocked_commands = [
            "sudo rm -rf /",
            "rm -rf /etc",
            "rm -rf /boot",
        ]

    def validate_plan(self, plan: CommandPlan) -> tuple[bool, Optional[str]]:
        """
        Validate a command plan before execution

        Returns:
            (is_valid, error_reason) — error_reason is None when valid
        """
        # Check for explicitly blocked commands
        for blocked in self.blocked_commands:
            if blocked in plan.command:
                return False, f"Command explicitly blocked: {blocked}"

        # Check classification
        if plan.classification == CommandClassification.BLOCKED:
            # FIX: was an f-string with no placeholders
            return False, "Command classified as BLOCKED"

        # For DANGEROUS commands, require strict approval
        if plan.classification == CommandClassification.DANGEROUS:
            if self.audit_level == CommandAuditLevel.PERMISSIVE:
                logger.warning(f"⚠️ DANGEROUS command allowed in permissive mode: {plan.command}")
            else:
                return False, "Dangerous commands require special approval"

        # A plan claiming to be SAFE must not contain destructive patterns.
        # (FIX: removed a dead local assignment that shadowed nothing and was
        # never read — it could not reclassify the plan.)
        for pattern in self.dangerous_patterns:
            if pattern in plan.command:
                if plan.classification == CommandClassification.SAFE:
                    return False, f"Command contains dangerous pattern: {pattern}"

        return True, None

    def validate_execution(
        self,
        plan: CommandPlan,
        executed_command: str,
        exit_code: int,
        output: str,
        error: str,
        execution_time_s: float,
        user_id: Optional[str] = None
    ) -> tuple[bool, Optional[str]]:
        """
        Validate execution matches plan

        Args:
            plan: The plan the agent declared before execution
            executed_command: What the backend actually ran
            exit_code: Process exit status
            output: Captured stdout
            error: Captured stderr
            execution_time_s: Wall-clock runtime
            user_id: Optional user attribution for the audit trail

        Returns:
            (is_valid, error_reason) — error_reason is None when valid
        """
        # Record execution
        execution = CommandExecution(
            command=executed_command,
            planned_hash=plan.get_hash(),
            executed_hash=hashlib.sha256(executed_command.encode()).hexdigest()[:16],
            classification=plan.classification,
            status="unknown",
            exit_code=exit_code,
            output=output,
            error=error,
            execution_time_s=execution_time_s,
            user_id=user_id
        )

        # Check for modification (plan hash vs executed hash)
        if execution.was_modified():
            if self.audit_level == CommandAuditLevel.ENFORCED:
                execution.status = "mismatch"
                self.audit_log.append(execution)
                return False, f"Command was modified during execution!\nPlanned: {plan.command}\nExecuted: {executed_command}"
            else:
                # Non-enforced modes log the mismatch but continue validating.
                logger.warning(
                    f"⚠️ Command modification detected:\n"
                    f"  Planned: {plan.command}\n"
                    f"  Executed: {executed_command}"
                )

        # Check timeout
        if execution_time_s > plan.max_execution_time_s:
            execution.status = "timeout"
            self.audit_log.append(execution)
            return False, f"Command exceeded max execution time: {execution_time_s:.1f}s > {plan.max_execution_time_s}s"

        # SAFE commands shouldn't produce errors with a non-zero exit code
        if error and plan.classification == CommandClassification.SAFE:
            if exit_code != 0:
                execution.status = "error"
                self.audit_log.append(execution)
                return False, f"Safe command produced error (exit {exit_code}): {error}"

        # Verify output matches expected pattern if specified (warn-only)
        if plan.expected_output_pattern:
            import re
            if not re.search(plan.expected_output_pattern, output):
                logger.warning(
                    f"⚠️ Output doesn't match expected pattern:\n"
                    f"  Pattern: {plan.expected_output_pattern}\n"
                    f"  Got: {output[:200]}"
                )

        # Mark as success
        execution.status = "success" if exit_code == 0 else "failure"
        self.audit_log.append(execution)

        return True, None

    def get_audit_log(self, limit: int = 50) -> List[str]:
        """Get the most recent `limit` audit log entries, formatted."""
        return [entry.to_audit_log() for entry in self.audit_log[-limit:]]

    def get_command_stats(self) -> Dict[str, Any]:
        """Get statistics about commands executed."""
        if not self.audit_log:
            return {"total": 0}

        total = len(self.audit_log)
        safe_count = sum(1 for e in self.audit_log if e.classification == CommandClassification.SAFE)
        write_count = sum(1 for e in self.audit_log if e.classification == CommandClassification.WRITE)
        dangerous_count = sum(1 for e in self.audit_log if e.classification == CommandClassification.DANGEROUS)
        failed_count = sum(1 for e in self.audit_log if e.status == "failure")
        modified_count = sum(1 for e in self.audit_log if e.was_modified())

        # total is guaranteed non-zero by the early return above.
        return {
            "total_executed": total,
            "safe_commands": safe_count,
            "write_commands": write_count,
            "dangerous_commands": dangerous_count,
            "failed_commands": failed_count,
            "modified_commands": modified_count,
            "modification_rate": modified_count / total,
            "failure_rate": failed_count / total,
        }

    def get_status_message(self) -> str:
        """Human-readable status summary of executions so far."""
        stats = self.get_command_stats()

        if stats.get("total_executed", 0) == 0:
            return "📋 **Command Execution Safety**: No commands executed yet"

        lines = [
            "📋 **Command Execution Safety**",
            f"• Audit level: {self.audit_level.value.upper()}",
            f"• Total executed: {stats['total_executed']}",
            f"• Safe commands: {stats['safe_commands']} | Write: {stats['write_commands']} | Dangerous: {stats['dangerous_commands']}",
            f"• Failures: {stats['failed_commands']} ({stats['failure_rate']:.1%})",
            f"• Modifications detected: {stats['modified_commands']} ({stats['modification_rate']:.1%})",
        ]

        if stats['modification_rate'] > 0:
            lines.append("\n⚠️ **WARNING**: Commands being modified during execution!")

        return "\n".join(lines)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class CommandSandbox:
    """
    Optional sandboxing for dangerous commands.
    Runs in isolated environment with limited permissions.
    """

    def __init__(self, enable_sandbox: bool = False):
        self.enabled = enable_sandbox

    def prepare_sandbox(self, command: str, dangerous: bool = False) -> str:
        """
        Wrap command in sandbox if needed.

        Returns the (possibly wrapped) command string; non-dangerous
        commands and disabled sandboxes pass through untouched.
        """
        if self.enabled and dangerous:
            # Use firejail if available — quiet mode with a hard 60s cap.
            return f"firejail --quiet --timeout=60 {command}"
        return command

    def cleanup_sandbox(self):
        """Clean up sandbox resources."""
        if self.enabled:
            logger.info("🧹 Cleaning up sandbox")
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# Global validator
# Module-level singleton: importers share one validator (and therefore one
# audit log) across the process. Defaults to STRICT auditing.
command_validator = CommandExecutionValidator()
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
if __name__ == "__main__":
    # Smoke-test the validator interactively.
    validator = CommandExecutionValidator(CommandAuditLevel.STRICT)

    # Case 1: a safe, read-only command should pass both checks.
    safe_plan = CommandPlan(
        command="cat /etc/hostname",
        classification=CommandClassification.SAFE,
        reason="Read system hostname",
        expected_output_pattern=r"\w+"
    )

    ok, reason = validator.validate_plan(safe_plan)
    print(f"Plan validation: {ok} - {reason}")

    # Simulate a faithful execution of the planned command.
    ok, reason = validator.validate_execution(
        safe_plan,
        executed_command="cat /etc/hostname",
        exit_code=0,
        output="localhost\n",
        error="",
        execution_time_s=0.05
    )
    print(f"Execution validation: {ok} - {reason}")

    # Case 2: a dangerous command should be rejected under STRICT auditing.
    risky_plan = CommandPlan(
        command="rm -rf /tmp/test",
        classification=CommandClassification.DANGEROUS,
        reason="Clean test directory"
    )

    ok, reason = validator.validate_plan(risky_plan)
    print(f"\nDangerous plan validation: {ok} - {reason}")

    # Dump the status summary and the last few audit entries.
    print("\n" + validator.get_status_message())
    print("\n📋 **Audit Log**")
    for record in validator.get_audit_log(5):
        print(f"  {record}")
|