agentic-threat-hunting-framework 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as published in their respective public registries.
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/METADATA +1 -1
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/RECORD +7 -14
- athf/agents/base.py +5 -14
- athf/core/clickhouse_connection.py +0 -396
- athf/core/metrics_tracker.py +0 -518
- athf/core/query_executor.py +0 -169
- athf/core/query_parser.py +0 -203
- athf/core/query_suggester.py +0 -235
- athf/core/query_validator.py +0 -240
- athf/core/session_manager.py +0 -764
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/WHEEL +0 -0
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/entry_points.txt +0 -0
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {agentic_threat_hunting_framework-0.5.0.dist-info → agentic_threat_hunting_framework-0.5.1.dist-info}/top_level.txt +0 -0
athf/core/query_parser.py
DELETED
@@ -1,203 +0,0 @@
-"""Query library parser for YAML query files."""
-
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
-
-import yaml
-
-
-class QueryParser:
-    """Parser for query library YAML files."""
-
-    def __init__(self, queries_dir: Path):
-        """Initialize query parser with queries directory.
-
-        Args:
-            queries_dir: Path to queries/ directory containing YAML files
-        """
-        self.queries_dir = Path(queries_dir)
-        self._queries_cache: Dict[str, Dict[str, Any]] = {}
-        self._loaded = False
-
-    def load_all_queries(self) -> List[Dict[str, Any]]:
-        """Load all queries from YAML files in queries/ directory.
-
-        Returns:
-            List of query dictionaries
-        """
-        if self._loaded and self._queries_cache:
-            return list(self._queries_cache.values())
-
-        all_queries = []
-        yaml_files = sorted(self.queries_dir.glob("*.yaml"))
-
-        for yaml_file in yaml_files:
-            try:
-                with open(yaml_file, "r", encoding="utf-8") as f:
-                    data = yaml.safe_load(f)
-
-                if not data or not isinstance(data, dict):
-                    continue
-
-                category = data.get("category")
-                queries = data.get("queries", [])
-
-                for query in queries:
-                    # Add category from file-level metadata
-                    query["category"] = category
-                    query["source_file"] = yaml_file.name
-
-                    # Cache by query_id
-                    query_id = query.get("query_id")
-                    if query_id:
-                        self._queries_cache[query_id] = query
-
-                    all_queries.append(query)
-
-            except Exception as e:
-                # Log error but continue loading other files
-                print(f"Warning: Failed to parse {yaml_file}: {e}")
-                continue
-
-        self._loaded = True
-        return all_queries
-
-    def get_query_by_id(self, query_id: str) -> Optional[Dict[str, Any]]:
-        """Retrieve query metadata by Q-XXXX ID.
-
-        Args:
-            query_id: Query ID (e.g., Q-USER-001)
-
-        Returns:
-            Query dictionary or None if not found
-        """
-        if not self._loaded:
-            self.load_all_queries()
-
-        return self._queries_cache.get(query_id)
-
-    def search_queries(self, keyword: str) -> List[Dict[str, Any]]:
-        """Search queries by keyword in name, description, tags.
-
-        Args:
-            keyword: Search keyword
-
-        Returns:
-            List of matching query dictionaries
-        """
-        all_queries = self.load_all_queries()
-        keyword_lower = keyword.lower()
-        results = []
-
-        for query in all_queries:
-            # Search in name
-            if keyword_lower in query.get("name", "").lower():
-                results.append(query)
-                continue
-
-            # Search in description
-            if keyword_lower in query.get("description", "").lower():
-                results.append(query)
-                continue
-
-            # Search in tags
-            tags = query.get("tags", [])
-            if any(keyword_lower in tag.lower() for tag in tags):
-                results.append(query)
-                continue
-
-        return results
-
-    def filter_queries(self, category: Optional[str] = None, tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
-        """Filter queries by category or tags.
-
-        Args:
-            category: Category to filter by (e.g., "user-activity")
-            tags: List of tags to filter by (any match)
-
-        Returns:
-            List of matching query dictionaries
-        """
-        all_queries = self.load_all_queries()
-        results = all_queries
-
-        if category:
-            results = [q for q in results if q.get("category") == category]
-
-        if tags:
-            tag_set = {t.lower() for t in tags}
-            results = [q for q in results if any(qt.lower() in tag_set for qt in q.get("tags", []))]
-
-        return results
-
-    def validate_query(self, query: Dict[str, Any]) -> Tuple[bool, List[str]]:
-        """Validate query structure and required fields.
-
-        Args:
-            query: Query dictionary
-
-        Returns:
-            Tuple of (is_valid, list_of_errors)
-        """
-        errors = []
-
-        # Required fields
-        required_fields = ["query_id", "name", "description", "query"]
-        for field in required_fields:
-            if not query.get(field):
-                errors.append(f"Missing required field: {field}")
-
-        # Validate query_id format: Q-[A-Z]+-[0-9]+
-        query_id = query.get("query_id", "")
-        if not re.match(r"^Q-[A-Z]+-\d+$", query_id):
-            errors.append(f"Invalid query_id format: {query_id} (expected Q-[CATEGORY]-[NUMBER])")
-
-        # Validate query includes time bounds
-        query_sql = query.get("query", "")
-        has_time_bound = "INTERVAL" in query_sql.upper() or "timestamp >=" in query_sql or "timestamp BETWEEN" in query_sql
-        if not has_time_bound:
-            errors.append("Query missing time constraint (INTERVAL or timestamp >=)")
-
-        # Validate query includes LIMIT clause
-        if "LIMIT" not in query_sql.upper():
-            errors.append("Query missing LIMIT clause")
-
-        # Validate placeholders match {{}} in query
-        placeholders_defined = set(query.get("placeholders", {}).keys())
-        placeholders_in_query = set(re.findall(r"\{\{(\w+)\}\}", query_sql))
-
-        # Check for placeholders used in query but not defined
-        undefined = placeholders_in_query - placeholders_defined
-        if undefined:
-            errors.append(f"Placeholders used but not defined: {', '.join(undefined)}")
-
-        # Check for placeholders defined but not used (warning, not error)
-        unused = placeholders_defined - placeholders_in_query
-        if unused:
-            # This is just a warning, not an error
-            pass
-
-        return (len(errors) == 0, errors)
-
-    def get_categories(self) -> List[str]:
-        """Get list of all unique categories in query library.
-
-        Returns:
-            List of category names
-        """
-        all_queries = self.load_all_queries()
-        categories: set[str] = {cat for q in all_queries if (cat := q.get("category")) is not None}
-        return sorted(categories)
-
-    def get_all_tags(self) -> List[str]:
-        """Get list of all unique tags in query library.
-
-        Returns:
-            List of tag names
-        """
-        all_queries = self.load_all_queries()
-        tags = set()
-        for query in all_queries:
-            tags.update(query.get("tags", []))
-        return sorted(tags)
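For context on what 0.5.1 removes, here is a minimal usage sketch of this QueryParser API as it existed in 0.5.0. The queries/ directory path and the "login" tag are illustrative assumptions; "Q-USER-001" and "user-activity" come from the docstrings above.

from pathlib import Path

from athf.core.query_parser import QueryParser

# Assumed layout: a queries/ directory of YAML files, as described above.
parser = QueryParser(Path("queries"))
all_queries = parser.load_all_queries()

# Look up one query by ID and check its structure.
query = parser.get_query_by_id("Q-USER-001")  # example ID from the docstring
if query is not None:
    is_valid, errors = parser.validate_query(query)
    print(is_valid, errors)

# Narrow the library by category and tags ("login" is a hypothetical tag).
user_queries = parser.filter_queries(category="user-activity", tags=["login"])
print(len(user_queries), parser.get_categories())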
athf/core/query_suggester.py
DELETED
@@ -1,235 +0,0 @@
-"""Query suggestion engine for LLM-driven alert triage."""
-
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Tuple
-
-from athf.core.query_parser import QueryParser
-
-
-class QuerySuggester:
-    """Suggests relevant queries based on alert text analysis."""
-
-    def __init__(self, queries_dir: Path):
-        """Initialize query suggester.
-
-        Args:
-            queries_dir: Path to queries/ directory
-        """
-        self.parser = QueryParser(queries_dir)
-        self.all_queries = self.parser.load_all_queries()
-
-    def extract_parameters(self, alert_text: str) -> Dict[str, str]:
-        """Extract username, hostname, and other parameters from alert text.
-
-        Patterns supported:
-        - "user: john.doe" or "username: john.doe" or "User: john.doe"
-        - "host: LAPTOP-123" or "hostname: LAPTOP-123" or "Host: LAPTOP-123"
-        - "process: powershell.exe" or "Process: powershell.exe"
-        - "ip: 8.8.8.8" or "IP: 8.8.8.8"
-        - "organization: acme-corp" or "org: acme-corp"
-
-        Args:
-            alert_text: Alert text to analyze
-
-        Returns:
-            Dictionary of extracted parameters
-        """
-        parameters = {}
-
-        # Username patterns
-        username_patterns = [
-            r"(?:user|username|User|USERNAME):\s*([a-zA-Z0-9._@-]+)",
-            r"user\s+([a-zA-Z0-9._@-]+)",
-            r"for\s+user\s+([a-zA-Z0-9._@-]+)",
-        ]
-        for pattern in username_patterns:
-            if match := re.search(pattern, alert_text):
-                parameters["username"] = match.group(1)
-                break
-
-        # Hostname patterns
-        hostname_patterns = [
-            r"(?:host|hostname|Host|HOSTNAME):\s*([a-zA-Z0-9._-]+)",
-            r"(?:on|from)\s+host\s+([a-zA-Z0-9._-]+)",
-            r"endpoint\s+([a-zA-Z0-9._-]+)",
-        ]
-        for pattern in hostname_patterns:
-            if match := re.search(pattern, alert_text):
-                parameters["hostname"] = match.group(1)
-                break
-
-        # Process name patterns
-        process_patterns = [
-            r"(?:process|Process|PROCESS):\s*([a-zA-Z0-9._-]+\.exe)",
-            r"(?:process|Process)\s+([a-zA-Z0-9._-]+\.exe)",
-            r"execution\s+of\s+([a-zA-Z0-9._-]+\.exe)",
-        ]
-        for pattern in process_patterns:
-            if match := re.search(pattern, alert_text):
-                parameters["process_name"] = match.group(1)
-                break
-
-        # IP address patterns
-        ip_pattern = r"(?:ip|IP|address):\s*(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
-        if match := re.search(ip_pattern, alert_text):
-            parameters["ip_address"] = match.group(1)
-
-        # Organization patterns
-        org_patterns = [
-            r"(?:organization|org):\s*([a-zA-Z0-9_-]+)",
-            r"customer\s+([a-zA-Z0-9_-]+)",
-        ]
-        for pattern in org_patterns:
-            if match := re.search(pattern, alert_text):
-                parameters["organization_id"] = match.group(1)
-                break
-
-        return parameters
-
-    def suggest_queries(self, alert_text: str, max_suggestions: int = 5) -> List[Tuple[int, Dict[str, Any]]]:
-        """Analyze alert text and suggest relevant queries with scores.
-
-        Scoring algorithm:
-        - Keyword in query name: +3 points
-        - Keyword in query description: +2 points
-        - Keyword in query tags: +1 point
-        - Extracted parameter matches query placeholder: +5 points
-
-        Args:
-            alert_text: Alert text to analyze
-            max_suggestions: Maximum number of queries to suggest
-
-        Returns:
-            List of tuples (score, query) sorted by relevance
-        """
-        # Extract parameters
-        parameters = self.extract_parameters(alert_text)
-
-        # Extract keywords
-        keywords = self._extract_keywords(alert_text)
-
-        # Score queries by relevance
-        scored_queries = []
-        for query in self.all_queries:
-            score = 0
-
-            # Match keywords against query metadata
-            for keyword in keywords:
-                keyword_lower = keyword.lower()
-
-                # Check name
-                if keyword_lower in query.get("name", "").lower():
-                    score += 3
-
-                # Check description
-                if keyword_lower in query.get("description", "").lower():
-                    score += 2
-
-                # Check tags
-                query_tags = " ".join(query.get("tags", [])).lower()
-                if keyword_lower in query_tags:
-                    score += 1
-
-            # Boost score if extracted parameters match query placeholders
-            query_placeholders = set(query.get("placeholders", {}).keys())
-            for param in parameters.keys():
-                if param in query_placeholders:
-                    score += 5
-
-            if score > 0:
-                scored_queries.append((score, query))
-
-        # Sort by score (descending) and return top N
-        scored_queries.sort(key=lambda x: x[0], reverse=True)
-        return scored_queries[:max_suggestions]
-
-    def _extract_keywords(self, text: str) -> List[str]:
-        """Extract relevant security keywords from alert text.
-
-        Args:
-            text: Alert text to analyze
-
-        Returns:
-            List of security-related keywords found
-        """
-        keywords = []
-        text_lower = text.lower()
-
-        # Security keywords organized by category
-        security_terms = {
-            # Execution
-            "powershell",
-            "cmd",
-            "bash",
-            "shell",
-            "script",
-            "execution",
-            "process",
-            # Credential Access
-            "credential",
-            "password",
-            "mimikatz",
-            "lsass",
-            "procdump",
-            "sam",
-            "ntds",
-            # Network
-            "network",
-            "connection",
-            "lateral",
-            "smb",
-            "rdp",
-            "winrm",
-            "exfiltration",
-            "c2",
-            "command-and-control",
-            # General
-            "suspicious",
-            "malicious",
-            "privilege",
-            "escalation",
-            "file",
-            "access",
-        }
-
-        for term in security_terms:
-            if term in text_lower:
-                keywords.append(term)
-
-        # Extract MITRE ATT&CK technique IDs
-        technique_pattern = r"T\d{4}(?:\.\d{3})?"
-        techniques = re.findall(technique_pattern, text, re.IGNORECASE)
-        keywords.extend(techniques)
-
-        return keywords
-
-    def get_parameter_coverage(self, query: Dict[str, Any], extracted_params: Dict[str, str]) -> float:
-        """Calculate what percentage of query placeholders can be filled.
-
-        Args:
-            query: Query dictionary
-            extracted_params: Extracted parameters from alert
-
-        Returns:
-            Coverage percentage (0.0 to 1.0)
-        """
-        placeholders = query.get("placeholders", {})
-        if not placeholders:
-            return 1.0  # No placeholders, fully covered
-
-        # Count how many placeholders can be filled
-        required = [name for name, info in placeholders.items() if "default" not in info]
-        optional = [name for name, info in placeholders.items() if "default" in info]
-
-        filled_required = sum(1 for r in required if r in extracted_params)
-        filled_optional = sum(1 for o in optional if o in extracted_params)
-
-        # Weight required higher than optional
-        total_score = len(required) * 2 + len(optional)
-        filled_score = filled_required * 2 + filled_optional
-
-        if total_score == 0:
-            return 1.0
-
-        return filled_score / total_score
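The suggester's scoring (name +3, description +2, tags +1, matched placeholder +5) is easiest to see end to end. A minimal sketch against the removed 0.5.0 API; the queries/ path and the alert text are fabricated for illustration.

from pathlib import Path

from athf.core.query_suggester import QuerySuggester

suggester = QuerySuggester(Path("queries"))  # assumed queries/ directory
alert = "Suspicious powershell execution on host LAPTOP-123 for user john.doe (T1059.001)"

# Regex extraction pulls structured fields from the free-text alert.
params = suggester.extract_parameters(alert)
# Expected here: {"username": "john.doe", "hostname": "LAPTOP-123"}

# Top-N queries by keyword/parameter score, with placeholder coverage per hit.
for score, query in suggester.suggest_queries(alert, max_suggestions=3):
    coverage = suggester.get_parameter_coverage(query, params)
    print(score, query.get("query_id"), f"{coverage:.0%}")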
athf/core/query_validator.py
DELETED
@@ -1,240 +0,0 @@
-"""Query validation for ClickHouse SQL safety checks."""
-
-import re
-from dataclasses import dataclass
-from typing import List, Optional
-
-
-@dataclass
-class ValidationResult:
-    """Result of query validation."""
-
-    is_valid: bool
-    errors: List[str]
-    warnings: List[str]
-    info: List[str]
-    estimated_row_count: Optional[int] = None
-
-
-class QueryValidator:
-    """Validates SQL queries for ClickHouse MCP safety requirements.
-
-    Enforces rules from integrations/clickhouse/README.md:
-    - Time bounds required (7 days max initially)
-    - LIMIT clause required (start with 100)
-    - No SQL comments (breaks MCP)
-    - Early filtering recommended
-    - Avoid multiple DISTINCT
-    """
-
-    # Time-related patterns
-    TIME_PATTERNS = [
-        r"toDate\(.+\)\s*>=\s*",  # toDate(timestamp) >= ...
-        r"toDateTime\(.+\)\s*>=\s*",  # toDateTime(timestamp) >= ...
-        r"timestamp\s*>=\s*",  # timestamp >= ...
-        r"time\s*>=\s*",  # time >= ...
-        r"WHERE.+>=\s*now\(\)\s*-\s*INTERVAL",  # WHERE ... >= now() - INTERVAL
-        r"WHERE.+>=\s*subtractDays",  # WHERE ... >= subtractDays()
-        r"WHERE.+>=\s*subtractHours",  # WHERE ... >= subtractHours()
-        r">=\s*toDateTime\(",  # >= toDateTime('...')
-        r">=\s*'[\d\-:\s]+'",  # >= '2025-01-15 00:00:00' or >= '2025-01-15'
-        r"<=\s*'[\d\-:\s]+'",  # <= '2025-01-15 23:59:59'
-        r"BETWEEN\s+['\"]\d{4}-\d{2}-\d{2}",  # BETWEEN '2025-01-15' AND '2025-01-22'
-        r"=\s*'[\d\-:\s]+'",  # = '2025-01-15' (exact date match)
-    ]
-
-    # LIMIT patterns
-    LIMIT_PATTERNS = [
-        r"LIMIT\s+\d+",  # LIMIT 100
-        r"TOP\s+\d+",  # TOP 100 (alternative syntax)
-    ]
-
-    # Comment patterns (these break MCP)
-    COMMENT_PATTERNS = [
-        r"--",  # Single-line comment
-        r"/\*.*?\*/",  # Multi-line comment
-    ]
-
-    # Expensive operations
-    EXPENSIVE_PATTERNS = [
-        (r"DISTINCT.*DISTINCT", "Multiple DISTINCT clauses detected (expensive operation)"),
-        (r"SELECT\s+\*\s+FROM", "SELECT * detected (prefer explicit columns for performance)"),
-        (r"GROUP BY.+ORDER BY.+LIMIT", "GROUP BY + ORDER BY without early filtering may be slow"),
-    ]
-
-    def __init__(self) -> None:
-        """Initialize the query validator."""
-        pass
-
-    def validate(self, query: str, target: str = "clickhouse") -> ValidationResult:
-        """Validate a SQL query against safety requirements.
-
-        Args:
-            query: The SQL query to validate
-            target: The target database (currently only 'clickhouse' supported)
-
-        Returns:
-            ValidationResult with validation status and messages
-        """
-        if target != "clickhouse":
-            return ValidationResult(
-                is_valid=False,
-                errors=[f"Unsupported target: {target} (only 'clickhouse' is currently supported)"],
-                warnings=[],
-                info=[],
-            )
-
-        errors: List[str] = []
-        warnings: List[str] = []
-        info: List[str] = []
-
-        # Normalize query (remove extra whitespace)
-        normalized = " ".join(query.split())
-
-        # 1. CRITICAL: Check for SQL comments (breaks MCP)
-        if self._has_comments(query):
-            errors.append("SQL comments detected (-- or /* */). Comments break MCP - remove them before execution.")
-
-        # 2. CRITICAL: Check for time bounds
-        if not self._has_time_bounds(normalized):
-            errors.append(
-                "No time bounds detected. Add time constraints (e.g., WHERE timestamp >= now() - INTERVAL 7 DAY "
-                "or WHERE time >= '2025-01-15 00:00:00') to prevent timeouts."
-            )
-        else:
-            info.append("Time bounds detected")
-
-        # 3. CRITICAL: Check for LIMIT clause
-        limit_value = self._extract_limit(normalized)
-        if limit_value is None:
-            errors.append("No LIMIT clause detected. Add LIMIT (start with 100, max 1000) to prevent large result sets.")
-        else:
-            if limit_value <= 100:
-                info.append(f"LIMIT {limit_value} (safe - good starting point)")
-            elif limit_value <= 1000:
-                info.append(f"LIMIT {limit_value} (moderate - may need refinement if exactly {limit_value} rows returned)")
-            else:
-                warnings.append(
-                    f"LIMIT {limit_value} is high. Consider starting with LIMIT 100 and increasing if needed "
-                    "(progressive strategy)."
-                )
-
-        # 4. Check for early filtering (WHERE clause before aggregations)
-        if not self._has_early_filtering(normalized):
-            warnings.append(
-                "No early filtering detected. Add WHERE clause before aggregations (GROUP BY, DISTINCT) "
-                "for better performance."
-            )
-        else:
-            info.append("Early filtering detected (WHERE clause present)")
-
-        # 5. Check for expensive operations
-        for pattern, message in self.EXPENSIVE_PATTERNS:
-            if re.search(pattern, normalized, re.IGNORECASE):
-                warnings.append(message)
-
-        # 6. Check for query structure (FROM clause required)
-        if not re.search(r"FROM\s+\w+", normalized, re.IGNORECASE):
-            errors.append("No FROM clause detected. Query must select from a table.")
-
-        # Determine overall validity
-        is_valid = len(errors) == 0
-
-        return ValidationResult(
-            is_valid=is_valid,
-            errors=errors,
-            warnings=warnings,
-            info=info,
-        )
-
-    def _has_comments(self, query: str) -> bool:
-        """Check if query contains SQL comments."""
-        # Check for single-line comments
-        if "--" in query:
-            return True
-
-        # Check for multi-line comments
-        if re.search(r"/\*.*?\*/", query, re.DOTALL):
-            return True
-
-        return False
-
-    def _has_time_bounds(self, normalized_query: str) -> bool:
-        """Check if query has time constraints."""
-        for pattern in self.TIME_PATTERNS:
-            if re.search(pattern, normalized_query, re.IGNORECASE):
-                return True
-        return False
-
-    def _extract_limit(self, normalized_query: str) -> Optional[int]:
-        """Extract LIMIT value from query."""
-        # Check for LIMIT clause
-        limit_match = re.search(r"LIMIT\s+(\d+)", normalized_query, re.IGNORECASE)
-        if limit_match:
-            return int(limit_match.group(1))
-
-        # Check for TOP clause (alternative syntax)
-        top_match = re.search(r"TOP\s+(\d+)", normalized_query, re.IGNORECASE)
-        if top_match:
-            return int(top_match.group(1))
-
-        return None
-
-    def _has_early_filtering(self, normalized_query: str) -> bool:
-        """Check if query has WHERE clause before aggregations."""
-        # Look for WHERE clause
-        where_match = re.search(r"WHERE", normalized_query, re.IGNORECASE)
-        if not where_match:
-            return False
-
-        # Check if WHERE comes before GROUP BY or DISTINCT
-        where_pos = where_match.start()
-
-        group_match = re.search(r"GROUP\s+BY", normalized_query, re.IGNORECASE)
-        if group_match and group_match.start() < where_pos:
-            return False
-
-        distinct_match = re.search(r"DISTINCT", normalized_query, re.IGNORECASE)
-        if distinct_match and distinct_match.start() < where_pos:
-            return False
-
-        return True
-
-    def suggest_improvements(self, query: str, validation: ValidationResult) -> List[str]:
-        """Suggest improvements based on validation results.
-
-        Args:
-            query: The original query
-            validation: The validation result
-
-        Returns:
-            List of suggested improvements
-        """
-        suggestions: List[str] = []
-
-        # If missing time bounds, suggest adding them
-        if any("time bounds" in error.lower() for error in validation.errors):
-            suggestions.append(
-                "Add time bounds: WHERE timestamp >= now() - INTERVAL 7 DAY "
-                "or WHERE time >= '2025-01-15 00:00:00' AND time <= '2025-01-22 23:59:59'"
-            )
-
-        # If missing LIMIT, suggest adding it
-        if any("limit" in error.lower() for error in validation.errors):
-            suggestions.append("Add LIMIT clause: LIMIT 100 (progressive strategy)")
-
-        # If has comments, suggest removing them
-        if any("comment" in error.lower() for error in validation.errors):
-            suggestions.append("Remove SQL comments (-- or /* */) - they break ClickHouse MCP")
-
-        # If missing early filtering
-        if any("early filtering" in warning.lower() for warning in validation.warnings):
-            suggestions.append(
-                "Add WHERE clause before aggregations for better performance: " "WHERE timestamp >= ... AND condition"
-            )
-
-        # If SELECT *
-        if any("SELECT \\*" in warning for warning in validation.warnings):
-            suggestions.append("Replace SELECT * with explicit columns: SELECT column1, column2, ...")
-
-        return suggestions
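Likewise for the removed validator: a minimal sketch of the validate-then-suggest loop under the 0.5.0 API. Both SQL strings and the logins table are invented examples.

from athf.core.query_validator import QueryValidator

validator = QueryValidator()

# Time-bounded, LIMITed, comment-free: passes all three critical checks.
ok = validator.validate(
    "SELECT user, count() FROM logins "
    "WHERE timestamp >= now() - INTERVAL 7 DAY GROUP BY user LIMIT 100"
)
print(ok.is_valid, ok.info)  # True, e.g. ["Time bounds detected", ...]

# No time bounds, no LIMIT, SELECT *: fails, and each error maps to a hint.
bad = validator.validate("SELECT * FROM logins")
for hint in validator.suggest_improvements("SELECT * FROM logins", bad):
    print("fix:", hint)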