rootly-mcp-server 2.0.14__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,404 @@
1
+ """
2
+ Security utilities for the Rootly MCP Server.
3
+
4
+ This module provides security-related functionality including:
5
+ - Secure token handling
6
+ - HTTPS enforcement
7
+ - Input sanitization
8
+ - Rate limiting
9
+ - Security validation
10
+ """
11
+
12
import inspect
import os
import re
import time
from collections import defaultdict, deque
from functools import wraps
from threading import Lock
from typing import Any
from urllib.parse import urlparse

from .exceptions import (
    RootlyConfigurationError,
    RootlyRateLimitError,
    RootlyValidationError,
)
26
+
27
# Shape of an opaque API token: at least 20 characters drawn from ASCII
# letters, digits, "_" and "-".
TOKEN_PATTERN = re.compile(r"^[A-Za-z0-9_-]{20,}$")

# URL shape checks.
HTTPS_PATTERN = re.compile(r"^https://")
# One DNS label: 1-63 alphanumerics/hyphens that neither start nor end with "-".
_DOMAIN_LABEL = r"[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
# One or more dot-terminated labels followed by an alphabetic TLD of 2+ letters.
VALID_DOMAIN_PATTERN = re.compile(rf"^(?:{_DOMAIN_LABEL}\.)+[a-zA-Z]{{2,}}$")

# SQL keywords that, as whole words, cause sanitize_input to reject a string.
_SQL_KEYWORDS = (
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "CREATE",
    "ALTER",
    "EXEC",
    "EXECUTE",
    "UNION",
    "SCRIPT",
)

# Regex sources (strings, not compiled) searched case-insensitively by sanitize_input.
SQL_INJECTION_PATTERNS = [
    r"(\b(" + "|".join(_SQL_KEYWORDS) + r")\b)",
    # Comment markers, statement separators and extended/system procedure prefixes.
    r"(--|;|\/\*|\*\/|xp_|sp_)",
    # Boolean operator followed later on the line by an "=" comparison.
    *(rf"(\b{operator}\b.*=.*)" for operator in ("OR", "AND")),
]

# Markup/script fragments that cause sanitize_input to reject a string.
XSS_PATTERNS = [
    r"<script[^>]*>.*?</script>",
    r"javascript:",
    # Inline event-handler attributes.
    *(rf"{handler}\s*=" for handler in ("onerror", "onload")),
    r"<iframe[^>]*>",
]
52
+
53
+
54
+ class RateLimiter:
55
+ """
56
+ Token bucket rate limiter for API requests.
57
+
58
+ Implements a sliding window rate limiter to prevent API abuse.
59
+ """
60
+
61
+ def __init__(self, max_requests: int = 100, time_window: int = 60):
62
+ """
63
+ Initialize the rate limiter.
64
+
65
+ Args:
66
+ max_requests: Maximum number of requests allowed in the time window
67
+ time_window: Time window in seconds
68
+ """
69
+ self.max_requests = max_requests
70
+ self.time_window = time_window
71
+ self._requests = defaultdict(list)
72
+ self._lock = Lock()
73
+
74
+ def is_allowed(self, identifier: str) -> tuple[bool, int | None]:
75
+ """
76
+ Check if a request is allowed for the given identifier.
77
+
78
+ Args:
79
+ identifier: Unique identifier for the client (e.g., IP address, user ID)
80
+
81
+ Returns:
82
+ Tuple of (is_allowed, retry_after_seconds)
83
+ """
84
+ with self._lock:
85
+ current_time = time.time()
86
+ window_start = current_time - self.time_window
87
+
88
+ # Clean up old requests
89
+ self._requests[identifier] = [
90
+ req_time for req_time in self._requests[identifier] if req_time > window_start
91
+ ]
92
+
93
+ # Check if limit is exceeded
94
+ if len(self._requests[identifier]) >= self.max_requests:
95
+ # Calculate retry_after based on oldest request
96
+ oldest_request = min(self._requests[identifier])
97
+ retry_after = int(oldest_request + self.time_window - current_time) + 1
98
+ return False, retry_after
99
+
100
+ # Allow the request and record it
101
+ self._requests[identifier].append(current_time)
102
+ return True, None
103
+
104
+ def reset(self, identifier: str) -> None:
105
+ """Reset the rate limit for a specific identifier."""
106
+ with self._lock:
107
+ self._requests.pop(identifier, None)
108
+
109
+
110
# Process-wide limiter (100 requests per identifier per 60 seconds); this is
# the default limiter consulted by ``rate_limit``.
_rate_limiter = RateLimiter(time_window=60, max_requests=100)


def get_rate_limiter() -> RateLimiter:
    """Return the shared module-level ``RateLimiter`` instance."""
    limiter = _rate_limiter
    return limiter
117
+
118
+
119
def rate_limit(identifier_func=None, *, limiter: "RateLimiter | None" = None):
    """
    Decorator to apply rate limiting to a function.

    Works for both regular and ``async def`` functions; the returned wrapper
    has the same sync/async nature as the wrapped function.

    Args:
        identifier_func: Optional function to extract identifier from function args.
            It is called with the same arguments as the decorated function.
            If None, uses "default" as identifier.
        limiter: Optional limiter to check against (anything with an
            ``is_allowed(identifier) -> (bool, retry_after)`` method). If None,
            the module-level ``_rate_limiter`` is looked up at call time.

    Raises (from the wrapped call):
        RootlyRateLimitError: When the limiter rejects the request.
    """

    def decorator(func):
        def enforce(*args, **kwargs) -> None:
            # Shared by both wrappers so sync and async paths cannot drift apart.
            identifier = identifier_func(*args, **kwargs) if identifier_func else "default"
            active = limiter if limiter is not None else _rate_limiter
            allowed, retry_after = active.is_allowed(identifier)

            if not allowed:
                raise RootlyRateLimitError(
                    f"Rate limit exceeded. Try again in {retry_after} seconds.",
                    retry_after=retry_after,
                )

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                enforce(*args, **kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            enforce(*args, **kwargs)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
163
+
164
+
165
+ def validate_api_token(token: str | None) -> str:
166
+ """
167
+ Validate that an API token is properly formatted and not empty.
168
+
169
+ Args:
170
+ token: The API token to validate
171
+
172
+ Returns:
173
+ The validated token
174
+
175
+ Raises:
176
+ RootlyConfigurationError: If token is missing or invalid
177
+ """
178
+ if not token:
179
+ raise RootlyConfigurationError(
180
+ "API token is required but not provided. Set the ROOTLY_API_TOKEN environment variable."
181
+ )
182
+
183
+ token = token.strip()
184
+
185
+ if len(token) < 20:
186
+ raise RootlyConfigurationError(
187
+ "API token appears to be invalid (too short). Please check your ROOTLY_API_TOKEN value."
188
+ )
189
+
190
+ # Don't log the actual token value for security
191
+ return token
192
+
193
+
194
def get_api_token_from_env() -> str:
    """
    Read ``ROOTLY_API_TOKEN`` from the process environment and validate it.

    Returns:
        The stripped, validated API token.

    Raises:
        RootlyConfigurationError: If the variable is unset, empty, or holds an
            invalid token (see ``validate_api_token``).
    """
    return validate_api_token(os.environ.get("ROOTLY_API_TOKEN"))
206
+
207
+
208
def enforce_https(url: str) -> str:
    """
    Ensure that a URL uses HTTPS.

    A URL without a scheme is upgraded to HTTPS. For example,
    ``"api.rootly.com"`` becomes ``"https://api.rootly.com"``. The
    protocol-relative form ``"//api.rootly.com/v1"`` becomes
    ``"https://api.rootly.com/v1"``.

    A URL whose scheme is ``https`` is returned unchanged. The scheme is
    matched case-insensitively, because ``urlparse`` lower-cases it.

    Args:
        url: The URL to validate

    Returns:
        The validated HTTPS URL

    Raises:
        RootlyValidationError: If URL is empty or uses a scheme other than HTTPS
    """
    if not url:
        raise RootlyValidationError("URL cannot be empty")

    parsed = urlparse(url)

    if not parsed.scheme:
        # A protocol-relative URL already carries the "//" authority marker;
        # prepending "https://" would produce "https:////host" with no host.
        if url.startswith("//"):
            return f"https:{url}"
        # Assume HTTPS if no scheme provided
        return f"https://{url}"

    if parsed.scheme != "https":
        raise RootlyValidationError(
            f"Only HTTPS URLs are allowed for security reasons. Got: {parsed.scheme}://"
        )

    return url
236
+
237
+
238
def validate_url(url: str, allowed_domains: list[str] | None = None) -> str:
    """
    Validate a URL for security.

    The URL is first passed through ``enforce_https``. The allow-list check
    uses the parsed *hostname*, so userinfo and port are ignored and the
    comparison is case-insensitive. For example, ``https://API.rootly.com:443``
    matches an allowed domain of ``rootly.com``.

    A host matches an allowed domain when it equals that domain or is a
    subdomain of it. Allowed-domain entries are normalized: whitespace is
    stripped, letters are lower-cased, and leading/trailing dots are removed.

    Args:
        url: The URL to validate
        allowed_domains: Optional list of allowed domains

    Returns:
        The validated URL (as returned by ``enforce_https``)

    Raises:
        RootlyValidationError: If URL is empty, not HTTPS, has no host, or the
            host is not in the allowed list
    """
    if not url:
        raise RootlyValidationError("URL cannot be empty")

    # Enforce HTTPS
    url = enforce_https(url)

    parsed = urlparse(url)

    # ``hostname`` is None for netlocs such as "" or ":443" that name no host.
    hostname = parsed.hostname
    if not hostname:
        raise RootlyValidationError(f"Invalid URL: missing domain in {url}")

    # Check against allowed domains if provided
    if allowed_domains:
        host = hostname.rstrip(".")  # tolerate fully-qualified "host." form
        normalized = [domain.strip().lower().strip(".") for domain in allowed_domains]
        domain_allowed = any(
            host == domain or host.endswith(f".{domain}")
            for domain in normalized
            if domain  # an empty entry would otherwise match any host ending in "."
        )
        if not domain_allowed:
            raise RootlyValidationError(f"Domain {parsed.netloc} is not in the allowed list")

    return url
274
+
275
+
276
+ def sanitize_input(value: Any, max_length: int = 10000) -> Any:
277
+ """
278
+ Sanitize user input to prevent injection attacks.
279
+
280
+ Args:
281
+ value: The value to sanitize
282
+ max_length: Maximum allowed length for strings
283
+
284
+ Returns:
285
+ The sanitized value
286
+
287
+ Raises:
288
+ RootlyValidationError: If input contains malicious patterns
289
+ """
290
+ if value is None:
291
+ return None
292
+
293
+ if isinstance(value, bool):
294
+ return value
295
+
296
+ if isinstance(value, int | float):
297
+ return value
298
+
299
+ if isinstance(value, str):
300
+ # Check length
301
+ if len(value) > max_length:
302
+ raise RootlyValidationError(
303
+ f"Input too long: {len(value)} characters (max: {max_length})"
304
+ )
305
+
306
+ # Check for SQL injection patterns
307
+ for pattern in SQL_INJECTION_PATTERNS:
308
+ if re.search(pattern, value, re.IGNORECASE):
309
+ raise RootlyValidationError("Input contains potentially malicious SQL patterns")
310
+
311
+ # Check for XSS patterns
312
+ for pattern in XSS_PATTERNS:
313
+ if re.search(pattern, value, re.IGNORECASE):
314
+ raise RootlyValidationError("Input contains potentially malicious XSS patterns")
315
+
316
+ return value
317
+
318
+ if isinstance(value, dict):
319
+ return {k: sanitize_input(v, max_length) for k, v in value.items()}
320
+
321
+ if isinstance(value, list | tuple):
322
+ return type(value)(sanitize_input(item, max_length) for item in value)
323
+
324
+ # For other types, convert to string and sanitize
325
+ return sanitize_input(str(value), max_length)
326
+
327
+
328
def sanitize_error_message(error_message: str, max_length: int = 500) -> str:
    """
    Sanitize error messages to prevent information leakage.

    The following are removed:

    - absolute ``.py`` file paths (POSIX, and Windows on any drive letter),
      which are replaced by ``[file]``;
    - ``", line N"`` markers;
    - everything from the first ``"Traceback"`` onward;
    - ``File "..."`` fragments.

    The result is whitespace-stripped. If it is longer than ``max_length`` it
    is cut so that the whole returned string, including the trailing ``"..."``,
    fits in ``max_length`` characters.

    Args:
        error_message: The error message to sanitize
        max_length: Maximum length for the sanitized message

    Returns:
        The sanitized error message, or "An error occurred" if nothing remains
    """
    if not error_message:
        return "An error occurred"

    # Remove absolute file paths (POSIX, then Windows with any drive letter)
    error_message = re.sub(r"/[\w/.-]+\.py", "[file]", error_message)
    error_message = re.sub(r"[A-Za-z]:\\[\w\\.-]+\.py", "[file]", error_message)

    # Remove line numbers
    error_message = re.sub(r", line \d+", "", error_message)

    # Remove "Traceback" and everything after it
    if "Traceback" in error_message:
        error_message = error_message.split("Traceback")[0].strip()

    # Remove stack trace markers
    error_message = re.sub(r"File \"[^\"]+\"", "", error_message)

    error_message = error_message.strip()

    # Truncate so the returned text, ellipsis included, respects max_length.
    ellipsis = "..."
    if len(error_message) > max_length:
        if max_length > len(ellipsis):
            error_message = error_message[: max_length - len(ellipsis)] + ellipsis
        else:
            error_message = error_message[: max(max_length, 0)]

    return error_message or "An error occurred"
363
+
364
+
365
+ def mask_sensitive_data(
366
+ data: dict[str, Any], sensitive_keys: list[str] | None = None
367
+ ) -> dict[str, Any]:
368
+ """
369
+ Mask sensitive data in a dictionary for logging.
370
+
371
+ Args:
372
+ data: The data dictionary to mask
373
+ sensitive_keys: List of key patterns to mask (case-insensitive)
374
+
375
+ Returns:
376
+ Dictionary with sensitive values masked
377
+ """
378
+ if sensitive_keys is None:
379
+ sensitive_keys = ["token", "password", "secret", "api_key", "auth"]
380
+
381
+ def should_mask(key: str) -> bool:
382
+ key_lower = key.lower()
383
+ return any(sensitive in key_lower for sensitive in sensitive_keys)
384
+
385
+ def mask_value(value: Any) -> Any:
386
+ if isinstance(value, str) and len(value) > 0:
387
+ return "***REDACTED***"
388
+ return value
389
+
390
+ masked = {}
391
+ for key, value in data.items():
392
+ if should_mask(key):
393
+ masked[key] = mask_value(value)
394
+ elif isinstance(value, dict):
395
+ masked[key] = mask_sensitive_data(value, sensitive_keys)
396
+ elif isinstance(value, list):
397
+ masked[key] = [
398
+ mask_sensitive_data(item, sensitive_keys) if isinstance(item, dict) else item
399
+ for item in value
400
+ ]
401
+ else:
402
+ masked[key] = value
403
+
404
+ return masked