iflow-mcp_splunk_splunk-mcp-server 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
guardrails.py ADDED
@@ -0,0 +1,307 @@
+ """SPL query validation and output sanitization functions."""
+
+ import re
+ from typing import Any, Tuple, Union
+
+ # Helper functions for complex validation rules
+ def check_collect_params(query: str, context: dict, base_risk: int) -> int:
+     """Check for risky collect parameters."""
+     if re.search(r'\|\s*collect\b', context['query_lower']):
+         if 'override=true' in context['query_lower'] or 'addtime=false' in context['query_lower']:
+             return base_risk
+     return 0
+
+
+ def check_outputlookup_params(query: str, context: dict, base_risk: int) -> int:
+     """Check for risky outputlookup parameters."""
+     if re.search(r'\|\s*outputlookup\b', context['query_lower']):
+         if 'override=true' in context['query_lower']:
+             return base_risk
+     return 0
+
+
+ def parse_time_to_hours(time_str: str) -> float:
+     """Convert a Splunk relative-time string (e.g. "-30d", "4h@h") to hours."""
+     time_str = time_str.strip().lower()
+
+     # Remove leading minus sign if present
+     if time_str.startswith('-'):
+         time_str = time_str[1:]
+
+     # Handle all-time keywords first, so "0" is not consumed by the
+     # relative-time regex below (earliest=0 means all-time in Splunk)
+     if time_str in ('0', 'all', 'alltime'):
+         return float('inf')  # All time
+
+     # Handle relative time modifiers, with an optional snap suffix such as "@d"
+     match = re.match(r'^(\d+)([smhdwmonqy]+)?(@[smhdwmonqy]+)?$', time_str)
+     if match:
+         value = int(match.group(1))
+         unit = match.group(2) if match.group(2) else 's'
+
+         # Convert to hours
+         multipliers = {
+             's': 1 / 3600,   # seconds to hours
+             'm': 1 / 60,     # minutes to hours
+             'h': 1,          # hours
+             'd': 24,         # days to hours
+             'w': 24 * 7,     # weeks to hours
+             'mon': 24 * 30,  # months to hours (approximate)
+             'q': 24 * 90,    # quarters to hours (approximate)
+             'y': 24 * 365,   # years to hours
+         }
+
+         return value * multipliers.get(unit, 1)
+
+     # Default to 24 hours if unparseable
+     return 24
+
+
+ def check_time_range(query: str, context: dict, base_risk: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, str]]:
+     """Check time range issues.
+
+     base_risk can be:
+     - int: single risk score
+     - tuple: (no_risk, exceeds_safe_range, all_time/no_time)
+
+     Returns:
+     - int: risk score
+     - OR tuple: (risk_score, time_range_type) where time_range_type is
+       'all_time', 'exceeds_safe', or 'no_time'
+     """
+     if isinstance(base_risk, tuple):
+         no_risk, exceeds_safe, all_time = base_risk
+     else:
+         # Backward compatibility
+         no_risk = 0
+         exceeds_safe = int(base_risk * 0.5)
+         all_time = base_risk
+
+     query_lower = context['query_lower']
+     safe_timerange_str = context.get('safe_timerange', '24h')
+     safe_hours = parse_time_to_hours(safe_timerange_str)
+
+     has_earliest = 'earliest' in query_lower or 'earliest_time' in query_lower
+     has_latest = 'latest' in query_lower or 'latest_time' in query_lower
+     has_time_range = has_earliest or has_latest
+
+     if not has_time_range:
+         # Check if it's an all-time search
+         if re.search(r'all\s*time|alltime', query_lower):
+             return (all_time, 'all_time')  # All-time search
+         else:
+             # No time range specified, could default to all-time
+             return (all_time, 'no_time')
+     else:
+         # Extract time range from the query
+         # Look for patterns like earliest=-30d or earliest=0
+         earliest_match = re.search(r'earliest(?:_time)?\s*=\s*([^\s,]+)', query_lower)
+
+         if earliest_match:
+             time_value = earliest_match.group(1)
+             query_hours = parse_time_to_hours(time_value)
+
+             # Check if it's all time (0 or inf)
+             if query_hours == float('inf') or time_value == '0':
+                 return (all_time, 'all_time')
+
+             # Check if the time range exceeds the safe range
+             if query_hours > safe_hours:
+                 return (exceeds_safe, 'exceeds_safe')
+
+     return no_risk
+
+
+ def check_index_usage(query: str, context: dict, base_risk: Union[int, Tuple[int, ...]]) -> int:
+     """Check for index usage patterns.
+
+     base_risk can be:
+     - int: single risk score
+     - tuple: (no_risk, no_index_with_constraints, index_star_unconstrained)
+     """
+     if isinstance(base_risk, tuple):
+         no_risk, no_index_constrained, index_star = base_risk
+     else:
+         # Backward compatibility
+         no_risk = 0
+         no_index_constrained = int(base_risk * 0.57)  # ~20/35
+         index_star = base_risk
+
+     query_lower = context['query_lower']
+
+     if 'index=*' in query_lower:
+         # Check if there are constraining source/sourcetype filters
+         if not (re.search(r'source\s*=', query_lower) or re.search(r'sourcetype\s*=', query_lower)):
+             return index_star  # Full risk for unconstrained index=*
+     elif not re.search(r'index\s*=', query_lower):
+         # No index specified
+         if re.search(r'source\s*=|sourcetype\s*=', query_lower):
+             return no_index_constrained
+     return no_risk
+
+
+ def check_subsearch_limits(query: str, context: dict, base_risk: int) -> int:
+     """Check for subsearches without limits."""
+     if '[' in query and ']' in query:
+         # Note: only the first bracketed subsearch is inspected
+         subsearch = query[query.find('['):query.find(']') + 1]
+         if 'maxout' not in subsearch.lower() and 'maxresults' not in subsearch.lower():
+             return base_risk
+     return 0
+
+
+ def check_expensive_commands(query: str, context: dict, base_risk: int) -> int:
+     """Check for expensive commands and return an appropriate score."""
+     query_lower = context['query_lower']
+     multiplier = 0
+
+     # Check each expensive command (each adds to the multiplier)
+     if re.search(r'\|\s*transaction\b', query_lower):
+         multiplier += 1
+     if re.search(r'\|\s*map\b', query_lower):
+         multiplier += 1
+     if re.search(r'\|\s*join\b', query_lower):
+         multiplier += 1
+
+     return int(base_risk * multiplier)
+
+
+ def check_append_operations(query: str, context: dict, base_risk: int) -> int:
+     """Check for append operations."""
+     if re.search(r'\|\s*(append|appendcols)\b', context['query_lower']):
+         return base_risk
+     return 0
+
+
+ # ===========================================================================
+ def validate_spl_query(query: str, safe_timerange: str) -> Tuple[int, str]:
+     """
+     Validate an SPL query and calculate a risk score using a rule-based system.
+
+     Args:
+         query: The SPL query to validate
+         safe_timerange: Safe time range from configuration
+
+     Returns:
+         Tuple of (risk_score, risk_message)
+     """
+     # Import here to avoid a circular dependency
+     from spl_risk_rules import SPL_RISK_RULES
+
+     risk_score = 0
+     issues = []
+     query_lower = query.lower()
+
+     # Context for function-based rules
+     context = {
+         'safe_timerange': safe_timerange,
+         'query_lower': query_lower
+     }
+
+     # Process all rules
+     for rule in SPL_RISK_RULES:
+         pattern_or_func, base_score, message = rule
+
+         if callable(pattern_or_func):
+             # It's a function - call it with base_score
+             result = pattern_or_func(query, context, base_score)
+
+             # Handle the special case where the function returns a (score, type) tuple
+             if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], str):
+                 score, time_type = result
+                 if score > 0:
+                     risk_score += score
+                     # Special handling for time range messages
+                     if pattern_or_func.__name__ == 'check_time_range':
+                         if time_type == 'all_time':
+                             formatted_message = f"All-time search detected (+{score}). This can be very resource intensive. Add time constraints like earliest=-24h latest=now to limit search scope."
+                         elif time_type == 'exceeds_safe':
+                             formatted_message = f"Time range exceeds safe limit (+{score}). Consider narrowing your search window for better performance."
+                         elif time_type == 'no_time':
+                             formatted_message = f"No time range specified (+{score}). Query may default to all-time. Add explicit time constraints like earliest=-24h latest=now."
+                         else:
+                             formatted_message = message.format(score=score)
+                     else:
+                         formatted_message = message.format(score=score)
+                     issues.append(formatted_message)
+             else:
+                 # Regular integer score
+                 score = result if isinstance(result, int) else 0
+                 if score > 0:
+                     risk_score += score
+                     # Format the message with the actual score
+                     formatted_message = message.format(score=score)
+                     issues.append(formatted_message)
+         else:
+             # It's a regex pattern
+             if re.search(pattern_or_func, query_lower):
+                 risk_score += base_score
+                 # Format the message with the base score
+                 formatted_message = message.format(score=base_score)
+                 issues.append(formatted_message)
+
+     # Cap the risk score at 100
+     risk_score = min(risk_score, 100)
+
+     # Build the final message
+     if not issues:
+         return risk_score, "Query appears safe."
+
+     risk_message = "Risk factors found:\n" + "\n".join(f"- {issue}" for issue in issues)
+
+     # Add a high-risk warning if needed
+     if risk_score >= 50:
+         risk_message += "\n\nConsider reviewing this query with your Splunk administrator."
+
+     return risk_score, risk_message
+ # ===========================================================================
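Note: validate_spl_query unpacks each SPL_RISK_RULES entry as a (pattern_or_function, base_score, message) triple. spl_risk_rules.py is not part of this diff, so the entries below are only a hypothetical sketch of the shape the loop above expects, not the package's actual rule table:

    # Hypothetical sketch of spl_risk_rules.py (rules and scores are illustrative)
    SPL_RISK_RULES = [
        # Plain regex rule: matched against the lowercased query, scored at base_score
        (r'\|\s*delete\b', 100, "Destructive delete command detected (+{score})."),
        # Function rule returning an int score
        (check_expensive_commands, 15, "Expensive command(s) found (+{score})."),
        # Function rule returning a (score, type) tuple; check_time_range also
        # accepts a (no_risk, exceeds_safe, all_time) tuple as base_risk
        (check_time_range, (0, 15, 30), "Time range issue (+{score})."),
    ]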
+
+ # ===========================================================================
+ def sanitize_output(data: Any) -> Any:
+     """
+     Recursively sanitize sensitive data in output.
+
+     Masks:
+     - Credit card numbers (showing only the trailing digit group)
+     - Social Security Numbers (complete masking)
+
+     Args:
+         data: Data to sanitize (can be dict, list, string, or other)
+
+     Returns:
+         Sanitized data with the same structure
+     """
+     # Credit card pattern - matches 15-18 digit sequences grouped 4-4-4-(3 to 6),
+     # with optional "-" or whitespace separators
+     cc_pattern = re.compile(r'\b(\d{4})[-\s]?(\d{4})[-\s]?(\d{4})[-\s]?(\d{3,6})\b')
+
+     # SSN pattern - matches XXX-XX-XXXX format
+     ssn_pattern = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')
+
+     def sanitize_string(text: str) -> str:
+         """Sanitize a single string value."""
+         if not isinstance(text, str):
+             return text
+
+         # Replace credit cards, keeping only the trailing digit group
+         def cc_replacer(match):
+             last_group = match.group(4)
+             # Determine the separator from the original value
+             separator = '-' if '-' in match.group(0) else ' ' if ' ' in match.group(0) else ''
+             return f"****{separator}****{separator}****{separator}{last_group}"
+
+         text = cc_pattern.sub(cc_replacer, text)
+
+         # Replace SSNs completely
+         text = ssn_pattern.sub('***-**-****', text)
+
+         return text
+
+     # Handle different data types
+     if isinstance(data, dict):
+         return {key: sanitize_output(value) for key, value in data.items()}
+     elif isinstance(data, list):
+         return [sanitize_output(item) for item in data]
+     elif isinstance(data, str):
+         return sanitize_string(data)
+     else:
+         # For other types (int, float, bool, None), return as-is
+         return data
+ # ===========================================================================
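A minimal usage sketch for the module's two public functions (the query, scores, and output shown are illustrative and assume an spl_risk_rules.py along the lines of the sketch above):

    from guardrails import validate_spl_query, sanitize_output

    score, message = validate_spl_query(
        "search index=* | join user [search index=auth]",
        safe_timerange="24h",
    )
    print(score)    # accumulated risk score, capped at 100
    print(message)  # "Risk factors found:" plus one line per triggered rule

    print(sanitize_output({"card": "4111-1111-1111-1111", "ssn": "123-45-6789"}))
    # {'card': '****-****-****-1111', 'ssn': '***-**-****'}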
helpers.py ADDED
@@ -0,0 +1,117 @@
+ """Helper functions for formatting Splunk search results."""
+
+ from typing import List, Dict, Any
+
+
+ def format_events_as_markdown(events: List[Dict[str, Any]], query: str) -> str:
+     """Convert events to markdown table format."""
+     if not events:
+         return f"Query: {query}\nNo events found."
+
+     # Get all unique keys from the events, preserving first-seen order
+     all_keys = []
+     seen_keys = set()
+     for event in events:
+         for key in event.keys():
+             if key not in seen_keys:
+                 all_keys.append(key)
+                 seen_keys.add(key)
+
+     # Build the markdown table
+     lines = [f"Query: {query}", f"Found: {len(events)} events", ""]
+
+     # Header
+     header = "| " + " | ".join(all_keys) + " |"
+     separator = "|" + "|".join(["-" * (len(key) + 2) for key in all_keys]) + "|"
+     lines.extend([header, separator])
+
+     # Rows
+     for event in events:
+         row_values = []
+         for key in all_keys:
+             value = str(event.get(key, ""))
+             # Escape pipe characters in values
+             value = value.replace("|", "\\|")
+             row_values.append(value)
+         row = "| " + " | ".join(row_values) + " |"
+         lines.append(row)
+
+     return "\n".join(lines)
+
+
+ def format_events_as_csv(events: List[Dict[str, Any]], query: str) -> str:
+     """Convert events to CSV format."""
+     if not events:
+         return f"# Query: {query}\n# No events found"
+
+     def csv_escape(value: str) -> str:
+         """Quote a field that contains commas, quotes, or newlines."""
+         if "," in value or '"' in value or "\n" in value:
+             value = '"' + value.replace('"', '""') + '"'
+         return value
+
+     # Get all unique keys, preserving first-seen order
+     all_keys = []
+     seen_keys = set()
+     for event in events:
+         for key in event.keys():
+             if key not in seen_keys:
+                 all_keys.append(key)
+                 seen_keys.add(key)
+
+     lines = [f"# Query: {query}", f"# Events: {len(events)}", ""]
+
+     # Header (field names need the same escaping as values)
+     lines.append(",".join(csv_escape(key) for key in all_keys))
+
+     # Rows
+     for event in events:
+         row_values = [csv_escape(str(event.get(key, ""))) for key in all_keys]
+         lines.append(",".join(row_values))
+
+     return "\n".join(lines)
+
+
+ def format_events_as_summary(events: List[Dict[str, Any]], query: str, event_count: int) -> str:
+     """Create a natural language summary of events."""
+     lines = [f"Query: {query}", f"Total events: {event_count}"]
+
+     if not events:
+         lines.append("No events found.")
+         return "\n".join(lines)
+
+     # Note when only a subset of the matched events is shown
+     if len(events) < event_count:
+         lines.append(f"Showing: First {len(events)} events")
+
+     # Time range analysis if _time exists (events are assumed newest-first,
+     # Splunk's default sort order, so the last element is the earliest)
+     if events and "_time" in events[0]:
+         times = [e.get("_time", "") for e in events if e.get("_time")]
+         if times:
+             lines.append(f"Time range: {times[-1]} to {times[0]}")
+
+     # Field analysis
+     all_fields = set()
+     for event in events:
+         all_fields.update(event.keys())
+
+     lines.append(f"Fields: {', '.join(sorted(all_fields))}")
+
+     # Value frequency analysis for common fields
+     for field in ["status", "sourcetype", "host", "source"]:
+         if field in all_fields:
+             values = [str(e.get(field, "")) for e in events if field in e]
+             if values:
+                 value_counts = {}
+                 for v in values:
+                     value_counts[v] = value_counts.get(v, 0) + 1
+                 top_values = sorted(value_counts.items(), key=lambda x: x[1], reverse=True)[:3]
+                 summary = ", ".join([f"{v[0]} ({v[1]})" for v in top_values])
+                 lines.append(f"{field.capitalize()} distribution: {summary}")
+
+     # Sample events
+     lines.append("\nFirst 3 events:")
+     for i, event in enumerate(events[:3], 1):
+         lines.append(f"Event {i}: " + " | ".join([f"{k}={v}" for k, v in event.items()]))
+
+     return "\n".join(lines)