bizteamai-smcp-biz 1.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
smcp/filters.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ Input filtering and sanitization for prompts and parameters.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, Union
7
+
8
+
9
class InputValidationError(Exception):
    """Signal that a prompt or parameter was rejected by a validation rule."""
12
+
13
+
14
def sanitize_prompt(prompt: str, cfg: Dict[str, Union[str, int]]) -> None:
    """
    Sanitize and validate prompt input against configured rules.

    Checks are applied in this order:
      * ``MAX_LEN``: reject prompts longer than this many characters
        (a missing or zero value disables the check).
      * ``SAFE_RE``: allow-list regex that the *entire* prompt must match
        (compiled with ``re.DOTALL`` so ``.`` also matches newlines).
      * ``BLOCKED_PATTERNS``: reject the prompt if any of these regexes is
        found anywhere in it (case-insensitive). A missing or ``None`` value
        means no blocked patterns.

    Args:
        prompt: The prompt text to validate
        cfg: Configuration dictionary containing validation rules

    Raises:
        InputValidationError: If the prompt fails validation
    """
    # Length validation
    max_len = cfg.get("MAX_LEN")
    if max_len and len(prompt) > max_len:
        raise InputValidationError(f"Prompt exceeds maximum length of {max_len}")

    # Pattern validation. fullmatch rather than match: re.match only anchors
    # at the start, so an unanchored allow-list such as r"[\w\s]+" would
    # accept "ok<script>", and "$" also matches just before a trailing newline.
    safe_re = cfg.get("SAFE_RE")
    if safe_re and not re.fullmatch(safe_re, prompt, re.DOTALL):
        raise InputValidationError("Prompt contains invalid characters or patterns")

    # Blocked patterns check; tolerate an explicit None in the config.
    for pattern in cfg.get("BLOCKED_PATTERNS") or []:
        if re.search(pattern, prompt, re.IGNORECASE):
            raise InputValidationError("Prompt contains blocked pattern")
41
+
42
+
43
def sanitize_parameter(param_name: str, param_value: Any, cfg: Dict[str, Any]) -> Any:
    """
    Validate one function parameter and hand it back unchanged.

    Strings, numbers (int/float) and lists are routed to their type-specific
    validator; values of any other type are returned without any checks.

    Args:
        param_name: Name of the parameter (used for rule lookup and messages)
        param_value: Value of the parameter
        cfg: Configuration dictionary containing validation rules

    Returns:
        The same ``param_value`` object that was passed in

    Raises:
        InputValidationError: If the parameter fails validation
    """
    if isinstance(param_value, str):
        validator = _validate_string_parameter
    elif isinstance(param_value, (int, float)):
        validator = _validate_numeric_parameter
    elif isinstance(param_value, list):
        validator = _validate_list_parameter
    else:
        validator = None

    if validator is not None:
        validator(param_name, param_value, cfg)
    return param_value
71
+
72
+
73
def _validate_string_parameter(param_name: str, value: str, cfg: Dict[str, Any]) -> None:
    """
    Validate a string parameter.

    Per-parameter rules from ``cfg["PARAM_RULES"][param_name]`` override the
    global defaults named in parentheses:
      * ``max_length`` (``MAX_PARAM_LEN``): maximum length; a missing or zero
        value disables the check.
      * ``pattern`` (``PARAM_SAFE_RE``): allow-list regex that the *entire*
        value must match.
      * ``blocked`` (``BLOCKED_PARAM_PATTERNS``): regexes that must not occur
        anywhere in the value (case-insensitive); ``None`` means none.

    Raises:
        InputValidationError: If the value violates any rule.
    """
    param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})

    # Length validation
    max_len = param_rules.get("max_length", cfg.get("MAX_PARAM_LEN"))
    if max_len and len(value) > max_len:
        raise InputValidationError(f"Parameter '{param_name}' exceeds maximum length")

    # Pattern validation. fullmatch: re.match would only check a prefix for
    # unanchored patterns, and "$" also matches before a trailing newline.
    pattern = param_rules.get("pattern", cfg.get("PARAM_SAFE_RE"))
    if pattern and not re.fullmatch(pattern, value):
        raise InputValidationError(f"Parameter '{param_name}' contains invalid characters")

    # Blocked content check; an explicit None disables it instead of crashing.
    blocked = param_rules.get("blocked", cfg.get("BLOCKED_PARAM_PATTERNS")) or []
    for block_pattern in blocked:
        if re.search(block_pattern, value, re.IGNORECASE):
            raise InputValidationError(f"Parameter '{param_name}' contains blocked content")
92
+
93
+
94
def _validate_numeric_parameter(param_name: str, value: Union[int, float], cfg: Dict[str, Any]) -> None:
    """
    Validate a numeric parameter against ``min_value`` / ``max_value``.

    Bounds come only from ``cfg["PARAM_RULES"][param_name]``; with no bounds
    configured, any number is accepted. When at least one bound is set, NaN
    is rejected, because NaN compares False against every bound and would
    otherwise slip through the range check.

    Raises:
        InputValidationError: If the value is out of range (or NaN while a
            range is configured).
    """
    import math

    param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})

    # Range validation
    min_val = param_rules.get("min_value")
    max_val = param_rules.get("max_value")

    # isinstance guard: math.isnan() on a huge int raises OverflowError.
    if (min_val is not None or max_val is not None) and isinstance(value, float) and math.isnan(value):
        raise InputValidationError(f"Parameter '{param_name}' is not a number")

    if min_val is not None and value < min_val:
        raise InputValidationError(f"Parameter '{param_name}' below minimum value {min_val}")

    if max_val is not None and value > max_val:
        raise InputValidationError(f"Parameter '{param_name}' exceeds maximum value {max_val}")
107
+
108
+
109
def _validate_list_parameter(param_name: str, value: list, cfg: Dict[str, Any]) -> None:
    """
    Validate a list parameter and each of its string/numeric items.

    The item count is limited by ``max_items`` from
    ``cfg["PARAM_RULES"][param_name]`` (falling back to ``MAX_LIST_ITEMS``;
    a missing or zero value disables the check). String items go through the
    string validator and int/float items through the numeric validator;
    items of other types are not checked.

    Raises:
        InputValidationError: If the list is too long or any item fails; the
            item's own error is chained as ``__cause__``.
    """
    param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})

    # Length validation
    max_items = param_rules.get("max_items", cfg.get("MAX_LIST_ITEMS"))
    if max_items and len(value) > max_items:
        raise InputValidationError(f"Parameter '{param_name}' has too many items")

    # NOTE(review): items are validated under the name "<param>[<i>]", so a
    # PARAM_RULES entry keyed by the list's own name is NOT applied to its
    # items; only rules keyed "<param>[<i>]" or the string validator's global
    # defaults are. Nested lists and dicts are not descended into. Confirm
    # this is intended.
    for i, item in enumerate(value):
        item_name = f"{param_name}[{i}]"
        try:
            if isinstance(item, str):
                _validate_string_parameter(item_name, item, cfg)
            elif isinstance(item, (int, float)):
                _validate_numeric_parameter(item_name, item, cfg)
        except InputValidationError as e:
            # Chain so the original per-item failure stays in the traceback.
            raise InputValidationError(f"List item validation failed: {e}") from e
127
+
128
+
129
def create_safe_regex(allowed_chars: str = None) -> str:
    """
    Build an anchored allow-list regex for input validation.

    The pattern always permits word characters (``\\w``) and whitespace
    (``\\s``); every character in ``allowed_chars`` is regex-escaped and added
    to the same character class. The empty string is also accepted.

    Args:
        allowed_chars: Extra literal characters to allow (optional)

    Returns:
        A pattern of the form ``^[...]*$``
    """
    extra = re.escape(allowed_chars) if allowed_chars else ""
    char_class = r"\w\s" + extra
    return "^[" + char_class + "]*$"
147
+
148
+
149
def strip_dangerous_content(text: str, cfg: Dict[str, Any]) -> str:
    """
    Strip potentially dangerous content from text.

    Removes ``<script>`` elements (including bodies that span several lines),
    the ``javascript:``, ``data:`` and ``vbscript:`` schemes, and any extra
    regexes listed in ``cfg["DANGEROUS_PATTERNS"]``. Matching is
    case-insensitive. Removal is repeated until the text stops changing, so
    fragments that reassemble into a pattern after one pass (for example
    ``"javajavascript:script:"``) are removed as well.

    NOTE(review): this is a blocklist; it is not a substitute for proper
    output encoding / an HTML sanitizer.

    Args:
        text: Text to sanitize
        cfg: Configuration dictionary

    Returns:
        Sanitized text with dangerous content removed
    """
    dangerous_patterns = [
        # (?s) lets .*? cross newlines; the closing tag may carry spaces/junk.
        r'(?s)<script[^>]*>.*?</script[^>]*>',  # Script tags
        r'javascript:',  # JavaScript URLs
        r'data:',  # Data URLs
        r'vbscript:',  # VBScript URLs
    ]

    # Add custom dangerous patterns from config (None is treated as empty)
    dangerous_patterns.extend(cfg.get("DANGEROUS_PATTERNS") or [])

    # Every pass that changes the text removes at least one character, so
    # this loop terminates.
    sanitized = text
    while True:
        previous = sanitized
        for pattern in dangerous_patterns:
            sanitized = re.sub(pattern, '', sanitized, flags=re.IGNORECASE)
        if sanitized == previous:
            return sanitized
smcp/license.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ License parsing, verification, and management for SMCP Business Edition.
3
+ """
4
+ import os
5
+ import hmac
6
+ import hashlib
7
+ import logging
8
+ import requests
9
+ from datetime import datetime, timezone
10
+ from typing import Optional, Dict, Any
11
+
12
logger = logging.getLogger(__name__)

# Global license state
# Parsed data of the last license that verify_license() accepted; None until
# verification has succeeded.
_license_info: Optional[Dict[str, Any]] = None
# Fallback HMAC key used by verify_license() when BIZTEAM_SERVER_SECRET is not
# set in the environment.
# SECURITY(review): this value is published with the package, and signatures
# are a plain HMAC-SHA256 over the key fields, so anyone can mint valid
# license keys for a deployment that relies on this fallback. Production
# deployments must set BIZTEAM_SERVER_SECRET.
_server_secret = "default-dev-secret-change-in-production" # Should be from env in production
17
+
18
class LicenseError(Exception):
    """Raised when a license key is malformed or carries invalid data."""
21
+
22
def parse_license_key(key: str) -> Dict[str, Any]:
    """
    Split a license key of the form ``BZT.<custID>.<cores>.<expiryUTC>.<nonce>.<sig>``.

    Surrounding whitespace is ignored. Only the core count is converted (to
    int); all other fields are returned as strings, unvalidated.

    Args:
        key: License key string

    Returns:
        Dict with keys ``prefix``, ``customer_id``, ``cores``, ``expiry``,
        ``nonce`` and ``signature``

    Raises:
        LicenseError: If the key does not have six dot-separated fields
            starting with ``BZT``, or the core count is not an integer.
    """
    fields = key.strip().split('.')
    if fields[0] != 'BZT' or len(fields) != 6:
        raise LicenseError("Invalid license key format")

    prefix, customer_id, cores_text, expiry, nonce, signature = fields
    try:
        cores = int(cores_text)
    except ValueError as e:
        raise LicenseError(f"Invalid license key data: {e}")

    return {
        'prefix': prefix,
        'customer_id': customer_id,
        'cores': cores,
        'expiry': expiry,
        'nonce': nonce,
        'signature': signature,
    }
50
+
51
def verify_signature(license_data: Dict[str, Any], secret: str) -> bool:
    """
    Verify the HMAC signature of a license key.

    The expected signature is the lowercase hex HMAC-SHA256, keyed with
    ``secret``, of ``BZT.<customer_id>.<cores>.<expiry>.<nonce>``. It is
    compared in constant time with ``license_data['signature']``.

    Args:
        license_data: Parsed license data
        secret: Server secret for verification

    Returns:
        True if the signature is valid; False otherwise, including when the
        supplied signature contains non-ASCII characters.
    """
    payload = f"BZT.{license_data['customer_id']}.{license_data['cores']}.{license_data['expiry']}.{license_data['nonce']}"

    expected_sig = hmac.new(
        secret.encode('utf-8'),
        payload.encode('utf-8'),
        hashlib.sha256
    ).hexdigest()

    # Compare bytes rather than str: compare_digest raises TypeError for str
    # arguments containing non-ASCII characters, which would turn a garbled
    # key into a crash instead of a rejection.
    return hmac.compare_digest(
        expected_sig.encode('ascii'),
        license_data['signature'].encode('utf-8'),
    )
71
+
72
def check_expiry(license_data: Dict[str, Any]) -> bool:
    """
    Check if license has expired (with 24h grace period).

    ``license_data['expiry']`` is parsed as a ``YYYYMMDD`` date at midnight
    UTC. The license is accepted until 24 hours after that instant.

    Args:
        license_data: Parsed license data

    Returns:
        True if license is still valid; False if it has expired or the expiry
        field is not a valid ``YYYYMMDD`` date.
    """
    from datetime import timedelta

    try:
        expiry_date = datetime.strptime(license_data['expiry'], '%Y%m%d')
    except ValueError:
        return False
    expiry_date = expiry_date.replace(tzinfo=timezone.utc)

    # Same as "now <= expiry + 24h", but moving the grace period to the
    # "now" side cannot overflow for far-future sentinel dates such as
    # 99991231 (expiry + 24h would raise OverflowError there).
    grace = timedelta(hours=24)
    return datetime.now(timezone.utc) - grace <= expiry_date
94
+
95
def check_revocation(license_key: str) -> bool:
    """
    Report whether a license key is absent from the remote revocation list.

    The key is identified by its SHA-256 hex digest, looked up in the
    ``revoked`` array of the JSON document served at
    ``BIZTEAM_REVOCATION_URL``. Setting ``BIZTEAM_SKIP_REVOCATION=1`` skips
    the lookup. The check deliberately fails open: a non-200 response, a
    network error or any other exception is logged and the key is treated as
    not revoked.

    Args:
        license_key: Full license key

    Returns:
        False only if the key's hash is on the revocation list; True otherwise
    """
    try:
        key_hash = hashlib.sha256(license_key.encode()).hexdigest()

        if os.getenv('BIZTEAM_SKIP_REVOCATION') == '1':
            logger.debug("Revocation checking disabled")
            return True

        revocation_url = os.getenv('BIZTEAM_REVOCATION_URL', 'https://revocation.bizteam.com/v1/list')
        response = requests.get(revocation_url, timeout=5)

        if response.status_code != 200:
            logger.warning(f"Could not check revocation list: HTTP {response.status_code}")
            return True

        if key_hash in response.json().get('revoked', []):
            logger.warning(f"License key is revoked: {key_hash[:16]}...")
            return False
        return True

    except requests.RequestException as e:
        logger.warning(f"Revocation check failed: {e}")
        # Fail open - allow license if revocation check fails
        return True
    except Exception as e:
        logger.error(f"Unexpected error during revocation check: {e}")
        return True
135
+
136
def load_license_key() -> Optional[str]:
    """
    Load license key from environment or file.

    A non-empty ``BIZTEAM_LICENSE_KEY`` takes precedence. Otherwise the file
    named by ``BIZTEAM_LICENSE_FILE`` (default ``/etc/bizteam/license.txt``)
    is read as UTF-8. A leading byte-order mark, as written by some Windows
    editors, is discarded so it cannot corrupt the ``BZT`` prefix. Surrounding
    whitespace is stripped in both cases.

    Returns:
        License key string, or None if the file is missing or unreadable
    """
    # Try environment variable first
    key = os.getenv('BIZTEAM_LICENSE_KEY')
    if key:
        return key.strip()

    # Try license file
    license_file = os.getenv('BIZTEAM_LICENSE_FILE', '/etc/bizteam/license.txt')
    try:
        with open(license_file, 'r', encoding='utf-8-sig') as f:
            return f.read().strip()
    except FileNotFoundError:
        logger.error(f"License file not found: {license_file}")
        return None
    except Exception as e:
        logger.error(f"Error reading license file: {e}")
        return None
159
+
160
def verify_license() -> Optional[Dict[str, Any]]:
    """
    Load and verify the license key.

    Steps, in order: load the key, parse it, check its HMAC signature against
    ``BIZTEAM_SERVER_SECRET`` (or the built-in fallback secret), check expiry,
    then check the revocation list. On success the parsed data is cached in
    the module-level ``_license_info``. The cache is cleared at the start of
    every call, so a failed re-verification (e.g. after the license expires
    or is revoked) no longer leaves the previous result cached.

    Returns:
        Verified license data or None if invalid
    """
    global _license_info

    # Invalidate first: only the success path below repopulates the cache.
    _license_info = None

    # Load license key
    key = load_license_key()
    if not key:
        logger.error("No license key found")
        return None

    try:
        # Parse license
        license_data = parse_license_key(key)

        # Get server secret
        secret = os.getenv('BIZTEAM_SERVER_SECRET', _server_secret)
        if secret == _server_secret:
            # The fallback ships with the package, so keys signed with it
            # can be forged by anyone who has the package.
            logger.warning(
                "Verifying license with the built-in development secret; "
                "set BIZTEAM_SERVER_SECRET in production"
            )

        # Verify signature
        if not verify_signature(license_data, secret):
            logger.error("License signature verification failed")
            return None

        # Check expiry
        if not check_expiry(license_data):
            logger.error("License has expired")
            return None

        # Check revocation
        if not check_revocation(key):
            logger.error("License has been revoked")
            return None

        # Cache license info
        _license_info = license_data
        logger.info(f"License verified for customer {license_data['customer_id']}")
        return license_data

    except LicenseError as e:
        logger.error(f"License validation failed: {e}")
        return None
205
+
206
def get_licensed_cores() -> int:
    """
    Return the licensed core count.

    Uses the cached license data when present; otherwise runs
    ``verify_license()`` once.

    Returns:
        Number of licensed cores, or 0 if no valid license
    """
    license_data = _license_info if _license_info else verify_license()
    if not license_data:
        return 0
    return license_data['cores']
219
+
220
def get_customer_id() -> Optional[str]:
    """
    Return the customer ID from the license.

    Uses the cached license data when present; otherwise runs
    ``verify_license()`` once.

    Returns:
        Customer ID or None if no valid license
    """
    license_data = _license_info if _license_info else verify_license()
    if not license_data:
        return None
    return license_data['customer_id']
smcp/logchain.py ADDED
@@ -0,0 +1,270 @@
1
+ """
2
+ Tamper-proof logging with SHA-256 chaining for audit trails.
3
+ """
4
+
5
+ import hashlib
6
+ import json
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+
12
class LogChain:
    """
    Append-only logger with SHA-256 chaining for tamper detection.

    Each log entry contains a hash of the previous entry, creating
    a chain that can detect tampering anywhere in the log.

    The whole chain is stored as one JSON array in ``log_path`` and is
    rewritten on every append. Writes go to a temporary sibling file that is
    then moved over the log, so a crash mid-write cannot truncate it. A log
    file that exists but cannot be parsed is treated as tampered:
    ``verify_chain`` reports it invalid and ``append`` refuses to overwrite it.
    """

    def __init__(self, log_path: str):
        """
        Initialize the log chain.

        Args:
            log_path: Path to the log file; missing parent directories are
                created.
        """
        self.log_path = Path(log_path)
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialize_log()

    def _initialize_log(self) -> None:
        """Write a genesis entry if the log file does not exist yet."""
        if not self.log_path.exists():
            genesis_entry = {
                "sequence": 0,
                "timestamp": time.time(),
                "event_type": "genesis",
                "data": "Log chain initialized",
                "previous_hash": "0" * 64,  # Genesis has no previous hash
            }
            genesis_entry["hash"] = self._calculate_hash(genesis_entry)
            self._save_log([genesis_entry])

    def _calculate_hash(self, entry: Dict[str, Any]) -> str:
        """Calculate the SHA-256 hex digest of an entry, excluding its "hash" field."""
        entry_copy = {k: v for k, v in entry.items() if k != "hash"}
        # default=str mirrors _save_log, so a value JSON cannot encode natively
        # hashes to exactly the text it is stored as, and verify_chain can
        # recompute the same digest after reloading.
        entry_json = json.dumps(entry_copy, sort_keys=True, separators=(',', ':'), default=str)
        return hashlib.sha256(entry_json.encode()).hexdigest()

    def _read_entries(self) -> List[Dict[str, Any]]:
        """
        Load log entries strictly.

        Returns:
            The entries, or [] if the log file does not exist.

        Raises:
            ValueError: If the file exists but is not a JSON array.
        """
        try:
            with open(self.log_path, 'r', encoding='utf-8') as f:
                entries = json.load(f)
        except FileNotFoundError:
            return []
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            raise ValueError(f"Log file {self.log_path} is corrupted: {e}") from e
        if not isinstance(entries, list):
            raise ValueError(f"Log file {self.log_path} does not contain a list of entries")
        return entries

    def _load_log(self) -> List[Dict[str, Any]]:
        """Load the current log entries leniently: [] if missing or unparseable."""
        try:
            return self._read_entries()
        except ValueError:
            return []

    def _save_log(self, entries: List[Dict[str, Any]]) -> None:
        """Atomically replace the log file with ``entries``."""
        tmp_path = self.log_path.with_name(self.log_path.name + ".tmp")
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(entries, f, indent=2, default=str)
        # Rename over the original so readers never see a half-written log.
        tmp_path.replace(self.log_path)

    def _get_last_hash(self) -> str:
        """Get the hash of the last log entry ("0" * 64 for an empty log)."""
        entries = self._load_log()
        return entries[-1]["hash"] if entries else "0" * 64

    def append(self, event_type: str, data: Any) -> str:
        """
        Append a new entry to the log chain.

        Args:
            event_type: Type of event being logged
            data: Event data to log; values JSON cannot encode natively are
                stored (and hashed) as ``str(value)``.

        Returns:
            Hash of the new entry

        Raises:
            ValueError: If the existing log file is corrupted; it is left
                untouched rather than overwritten.
        """
        entries = self._read_entries()
        previous_hash = entries[-1]["hash"] if entries else "0" * 64

        entry = {
            "sequence": len(entries),
            "timestamp": time.time(),
            "event_type": event_type,
            "data": data,
            "previous_hash": previous_hash,
        }
        entry["hash"] = self._calculate_hash(entry)

        entries.append(entry)
        self._save_log(entries)

        return entry["hash"]

    def verify_chain(self) -> Tuple[bool, Optional[int]]:
        """
        Verify the integrity of the log chain.

        Returns:
            Tuple of (is_valid, first_invalid_sequence)
            is_valid is False if tampering is detected
            first_invalid_sequence is the index of the first invalid entry
            (0 if the file cannot be parsed at all). A missing or empty log
            is reported as valid.
        """
        try:
            entries = self._read_entries()
        except ValueError:
            return False, 0

        # Genesis must link to the all-zero hash; every later entry to its
        # predecessor's hash.
        expected_previous = "0" * 64
        for i, entry in enumerate(entries):
            if not isinstance(entry, dict):
                return False, i
            if entry.get("previous_hash") != expected_previous:
                return False, i
            if self._calculate_hash(entry) != entry.get("hash"):
                return False, i
            expected_previous = entry["hash"]

        return True, None

    def get_entries(self, start_sequence: int = 0, end_sequence: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get log entries in a range.

        Args:
            start_sequence: Starting sequence number (inclusive)
            end_sequence: Ending sequence number (inclusive, None for all)

        Returns:
            List of log entries in the specified range
        """
        entries = self._load_log()

        if end_sequence is None:
            end_sequence = len(entries) - 1

        return [entry for entry in entries
                if start_sequence <= entry["sequence"] <= end_sequence]

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the log chain.

        Always includes ``total_entries`` and ``chain_valid``. A non-empty log
        also includes ``first_invalid_sequence``, ``first_entry_time``,
        ``last_entry_time`` and ``event_type_counts``. ``chain_valid`` comes
        from ``verify_chain``, so a corrupted file is reported as invalid
        even though no entries can be read from it.
        """
        entries = self._load_log()
        is_valid, invalid_seq = self.verify_chain()

        if not entries:
            return {"total_entries": 0, "chain_valid": is_valid}

        event_types: Dict[Any, int] = {}
        for entry in entries:
            event_type = entry.get("event_type")
            event_types[event_type] = event_types.get(event_type, 0) + 1

        return {
            "total_entries": len(entries),
            "chain_valid": is_valid,
            "first_invalid_sequence": invalid_seq,
            "first_entry_time": entries[0].get("timestamp"),
            "last_entry_time": entries[-1].get("timestamp"),
            "event_type_counts": event_types,
        }
174
+
175
+
176
# Global logger instance: the process-wide LogChain, created lazily by
# get_logger(); None until the first call with a LOG_PATH configured.
_logger = None
178
+
179
+
180
def get_logger(cfg: Dict[str, Any]) -> Optional[LogChain]:
    """
    Get or create the global logger if logging is enabled.

    Returns None when ``cfg`` has no truthy ``LOG_PATH`` (logging disabled).
    The LogChain is cached in the module-level ``_logger``. If a later call
    passes a different ``LOG_PATH``, a new LogChain for that path replaces
    the cached one, rather than events being silently written to the file
    from the first call.
    """
    global _logger

    log_path = cfg.get("LOG_PATH")
    if not log_path:
        return None

    if _logger is None or _logger.log_path != Path(log_path):
        _logger = LogChain(log_path)

    return _logger
192
+
193
+
194
def log_event(function_name: str, args: Tuple, kwargs: Dict[str, Any], result: Any, cfg: Dict[str, Any]) -> None:
    """
    Record a successful function call as a "function_execution" entry.

    Does nothing when ``get_logger(cfg)`` yields no logger. Keyword arguments
    whose names start with an underscore are omitted. Positional arguments and
    all other keyword values are recorded as given. The result itself is not
    stored, only its type name and the length of its ``str()`` form (0 for None).

    Args:
        function_name: Name of the executed function
        args: Function arguments
        kwargs: Function keyword arguments
        result: Function result
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if not chain:
        return

    visible_kwargs = {name: value for name, value in kwargs.items() if not name.startswith("_")}
    result_size = 0 if result is None else len(str(result))

    chain.append("function_execution", {
        "function": function_name,
        "args": args,
        "kwargs": visible_kwargs,
        "result_type": type(result).__name__,
        "result_size": result_size,
        "success": True,
    })
222
+
223
+
224
def log_security_event(event_type: str, details: Dict[str, Any], cfg: Dict[str, Any]) -> None:
    """
    Record a "security_event" entry when logging is enabled for ``cfg``.

    Args:
        event_type: Type of security event (stored as ``security_event_type``)
        details: Event details, stored as given
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if chain:
        chain.append("security_event", {"security_event_type": event_type, "details": details})
243
+
244
+
245
def log_error(function_name: str, error: Exception, args: Tuple, kwargs: Dict[str, Any], cfg: Dict[str, Any]) -> None:
    """
    Record a failed function call as a "function_error" entry.

    Does nothing when ``get_logger(cfg)`` yields no logger. The exception is
    stored as its class name and ``str()`` message. Keyword arguments whose
    names start with an underscore are omitted; everything else is recorded
    as given.

    Args:
        function_name: Name of the function that errored
        error: The exception that occurred
        args: Function arguments
        kwargs: Function keyword arguments
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if not chain:
        return

    visible_kwargs = {name: value for name, value in kwargs.items() if not name.startswith("_")}

    chain.append("function_error", {
        "function": function_name,
        "error_type": type(error).__name__,
        "error_message": str(error),
        "args": args,
        "kwargs": visible_kwargs,
    })