bizteamai-smcp 1.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
smcp/filters.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ Input filtering and sanitization for prompts and parameters.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, Union
7
+
8
+
9
class InputValidationError(Exception):
    """Signals that a prompt or parameter was rejected by an input filter."""
12
+
13
+
14
def sanitize_prompt(prompt: str, cfg: Dict[str, Union[str, int]]) -> None:
    """
    Check a prompt against the length and pattern rules found in ``cfg``.

    Three optional rules are applied in order; a missing or falsy rule is
    skipped:

    * ``MAX_LEN`` - upper bound on ``len(prompt)``.
    * ``SAFE_RE`` - pattern the prompt must match from its first character
      (``re.match`` with ``re.DOTALL``).
    * ``BLOCKED_PATTERNS`` - patterns that must not occur anywhere in the
      prompt (case-insensitive ``re.search``).

    Args:
        prompt: The prompt text to validate
        cfg: Configuration dictionary containing validation rules

    Raises:
        InputValidationError: If the prompt fails validation
    """
    limit = cfg.get("MAX_LEN")
    if limit and len(prompt) > limit:
        raise InputValidationError(f"Prompt exceeds maximum length of {limit}")

    allow_re = cfg.get("SAFE_RE")
    if allow_re and re.match(allow_re, prompt, re.DOTALL) is None:
        raise InputValidationError("Prompt contains invalid characters or patterns")

    deny_list = cfg.get("BLOCKED_PATTERNS", [])
    if any(re.search(deny_re, prompt, re.IGNORECASE) for deny_re in deny_list):
        raise InputValidationError("Prompt contains blocked pattern")
41
+
42
+
43
def sanitize_parameter(param_name: str, param_value: Any, cfg: Dict[str, Any]) -> Any:
    """
    Validate one function parameter according to its runtime type.

    Strings, numbers (``int``/``float``) and lists are handed to the
    matching private validator; values of any other type are returned
    without being checked.

    Args:
        param_name: Name of the parameter
        param_value: Value of the parameter
        cfg: Configuration dictionary containing validation rules

    Returns:
        ``param_value`` itself, unmodified

    Raises:
        InputValidationError: If the parameter fails validation
    """
    if isinstance(param_value, str):
        validator = _validate_string_parameter
    elif isinstance(param_value, (int, float)):
        validator = _validate_numeric_parameter
    elif isinstance(param_value, list):
        validator = _validate_list_parameter
    else:
        # Unsupported types pass through untouched.
        return param_value

    validator(param_name, param_value, cfg)
    return param_value
71
+
72
+
73
+ def _validate_string_parameter(param_name: str, value: str, cfg: Dict[str, Any]) -> None:
74
+ """Validate string parameters."""
75
+ param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})
76
+
77
+ # Length validation
78
+ max_len = param_rules.get("max_length", cfg.get("MAX_PARAM_LEN"))
79
+ if max_len and len(value) > max_len:
80
+ raise InputValidationError(f"Parameter '{param_name}' exceeds maximum length")
81
+
82
+ # Pattern validation
83
+ pattern = param_rules.get("pattern", cfg.get("PARAM_SAFE_RE"))
84
+ if pattern and not re.match(pattern, value):
85
+ raise InputValidationError(f"Parameter '{param_name}' contains invalid characters")
86
+
87
+ # Blocked content check
88
+ blocked = param_rules.get("blocked", cfg.get("BLOCKED_PARAM_PATTERNS", []))
89
+ for block_pattern in blocked:
90
+ if re.search(block_pattern, value, re.IGNORECASE):
91
+ raise InputValidationError(f"Parameter '{param_name}' contains blocked content")
92
+
93
+
94
+ def _validate_numeric_parameter(param_name: str, value: Union[int, float], cfg: Dict[str, Any]) -> None:
95
+ """Validate numeric parameters."""
96
+ param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})
97
+
98
+ # Range validation
99
+ min_val = param_rules.get("min_value")
100
+ max_val = param_rules.get("max_value")
101
+
102
+ if min_val is not None and value < min_val:
103
+ raise InputValidationError(f"Parameter '{param_name}' below minimum value {min_val}")
104
+
105
+ if max_val is not None and value > max_val:
106
+ raise InputValidationError(f"Parameter '{param_name}' exceeds maximum value {max_val}")
107
+
108
+
109
+ def _validate_list_parameter(param_name: str, value: list, cfg: Dict[str, Any]) -> None:
110
+ """Validate list parameters."""
111
+ param_rules = cfg.get("PARAM_RULES", {}).get(param_name, {})
112
+
113
+ # Length validation
114
+ max_items = param_rules.get("max_items", cfg.get("MAX_LIST_ITEMS"))
115
+ if max_items and len(value) > max_items:
116
+ raise InputValidationError(f"Parameter '{param_name}' has too many items")
117
+
118
+ # Validate each item in the list
119
+ for i, item in enumerate(value):
120
+ try:
121
+ if isinstance(item, str):
122
+ _validate_string_parameter(f"{param_name}[{i}]", item, cfg)
123
+ elif isinstance(item, (int, float)):
124
+ _validate_numeric_parameter(f"{param_name}[{i}]", item, cfg)
125
+ except InputValidationError as e:
126
+ raise InputValidationError(f"List item validation failed: {e}")
127
+
128
+
129
def create_safe_regex(allowed_chars: Union[str, None] = None) -> str:
    """
    Create a safe regex pattern for input validation.

    The pattern accepts only strings made up entirely of word characters
    (``\\w``), whitespace (``\\s``) and the extra characters supplied.

    Args:
        allowed_chars: Additional characters to allow beyond word characters
            and whitespace. ``None`` or an empty string adds nothing.

    Returns:
        Regex pattern string of the form ``^[...]*$``
    """
    char_class = r"\w\s"

    if allowed_chars:
        # re.escape neutralises characters such as "]", "^", "-" and "\" that
        # would otherwise change the meaning of the character class.
        char_class += re.escape(allowed_chars)

    return f"^[{char_class}]*$"
147
+
148
+
149
def strip_dangerous_content(text: str, cfg: Dict[str, Any]) -> str:
    """
    Strip potentially dangerous content from text.

    Removes ``<script>...</script>`` blocks (including ones that span several
    lines) and the ``javascript:``, ``data:`` and ``vbscript:`` scheme
    prefixes, plus any extra patterns listed under ``DANGEROUS_PATTERNS`` in
    ``cfg``. All patterns are matched case-insensitively.

    Removal is repeated until the text stops changing, so that deleting one
    match cannot splice the surrounding characters into a new match (for
    example ``javajavascript:script:`` collapsing into ``javascript:``).

    This is a pattern-based filter, not an HTML parser; it should not be
    relied on as the only defence against markup injection.

    Args:
        text: Text to sanitize
        cfg: Configuration dictionary

    Returns:
        Sanitized text with dangerous content removed
    """
    dangerous_patterns = [
        # [\s\S] instead of "." so a script body containing newlines is
        # matched without changing how custom patterns treat ".".
        r'<script[^>]*>[\s\S]*?</script>',  # Script tags
        r'javascript:',                     # JavaScript URLs
        r'data:',                           # Data URLs
        r'vbscript:',                       # VBScript URLs
    ]
    dangerous_patterns.extend(cfg.get("DANGEROUS_PATTERNS", []))

    compiled = [re.compile(pattern, re.IGNORECASE) for pattern in dangerous_patterns]

    sanitized = text
    while True:
        before = sanitized
        for regex in compiled:
            sanitized = regex.sub('', sanitized)
        # Every change shortens the string, so this loop always terminates.
        if sanitized == before:
            return sanitized
smcp/license.py ADDED
@@ -0,0 +1,113 @@
1
+ """
2
+ License verification and core limit enforcement for SMCP Business Edition.
3
+ """
4
+ import os
5
+ import hmac
6
+ import hashlib
7
+ import logging
8
+ from datetime import datetime, timezone
9
+ from typing import Optional, Tuple
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class LicenseError(Exception):
    """Common base class for every license-related failure."""


class LicenseExpiredError(LicenseError):
    """The license key's expiry date has passed."""


class LicenseInvalidError(LicenseError):
    """The license key is malformed, corrupted or fails verification."""


class CoreLimitExceededError(LicenseError):
    """The CPU core limit has been exceeded."""
28
+
29
def parse_license_key(key: str) -> dict:
    """
    Split a license key into its named fields.

    A key has six dot-separated fields:
    ``BZT.<customer_id>.<cores>.<expiry_utc>.<nonce>.<signature>``.
    Surrounding whitespace is ignored and ``cores`` is converted to ``int``.

    Raises:
        LicenseInvalidError: If the key does not have six fields or does not
            start with ``BZT``
        ValueError: If the cores field is not an integer
    """
    fields = key.strip().split('.')
    if not (len(fields) == 6 and fields[0] == 'BZT'):
        raise LicenseInvalidError("Invalid license key format")

    prefix, customer_id, cores, expiry_utc, nonce, signature = fields
    return {
        'prefix': prefix,
        'customer_id': customer_id,
        'cores': int(cores),
        'expiry_utc': expiry_utc,
        'nonce': nonce,
        'signature': signature,
    }
43
+
44
def verify_license_signature(key_data: dict, secret: str) -> bool:
    """
    Verify the HMAC-SHA256 signature of a parsed license key.

    The signed payload is
    ``<prefix>.<customer_id>.<cores>.<expiry_utc>.<nonce>`` and the expected
    signature is its lowercase hex HMAC-SHA256 digest under ``secret``.

    Args:
        key_data: Parsed license fields, including ``signature``
        secret: Shared secret used as the HMAC key

    Returns:
        True if ``key_data['signature']`` equals the expected digest
    """
    payload = ".".join(
        str(key_data[field])
        for field in ('prefix', 'customer_id', 'cores', 'expiry_utc', 'nonce')
    )
    expected_sig = hmac.new(
        secret.encode('utf-8'),
        payload.encode('utf-8'),
        hashlib.sha256,
    ).hexdigest()

    # compare_digest raises TypeError when given a str containing non-ASCII
    # characters. The signature comes from the (untrusted) license key, so
    # compare bytes instead; backslashreplace guarantees encoding never fails
    # and can never produce a valid hex digest.
    provided_sig = str(key_data['signature']).encode('utf-8', 'backslashreplace')
    return hmac.compare_digest(expected_sig.encode('ascii'), provided_sig)
53
+
54
def check_license_expiry(expiry_str: str) -> bool:
    """
    Report whether a license has expired.

    The license stays valid through the end of its expiry date: it counts as
    expired only once the current UTC time is later than 23:59:59 UTC on
    that date. No further grace period is added beyond the expiry day.

    Args:
        expiry_str: Expiry date in ``YYYYMMDD`` form

    Returns:
        True if the license has expired, False if it is still valid

    Raises:
        LicenseInvalidError: If ``expiry_str`` is not a valid ``YYYYMMDD`` date
    """
    try:
        expiry_day = datetime.strptime(expiry_str, '%Y%m%d')
    except ValueError as exc:
        raise LicenseInvalidError("Invalid expiry date format") from exc

    valid_until = expiry_day.replace(
        hour=23, minute=59, second=59, tzinfo=timezone.utc
    )
    return datetime.now(timezone.utc) > valid_until
63
+
64
def load_license_key() -> Optional[str]:
    """
    Load the license key from the environment or from a license file.

    ``BIZTEAM_LICENSE_KEY`` takes precedence when it is set to a non-empty
    value. Otherwise the file named by ``BIZTEAM_LICENSE_FILE`` (default
    ``/etc/bizteam/license.txt``) is read and stripped of surrounding
    whitespace.

    Returns:
        The license key, or None if no key is configured or the license file
        is missing or cannot be read
    """
    key = os.getenv('BIZTEAM_LICENSE_KEY')
    if key:
        return key

    license_file = os.getenv('BIZTEAM_LICENSE_FILE', '/etc/bizteam/license.txt')
    try:
        with open(license_file, 'r', encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        return None
    except (OSError, UnicodeDecodeError) as exc:
        # A directory, a permissions problem or a non-text file means there is
        # no usable key; report it instead of crashing license verification.
        logging.getLogger(__name__).warning(
            "Could not read license file %s: %s", license_file, exc
        )
        return None
78
+
79
def verify_license(secret: str) -> Tuple[bool, dict]:
    """
    Load the configured license key and validate it.

    The key must parse, carry a signature that matches ``secret`` and not be
    expired. Any failure is logged and reported as an invalid license rather
    than raised.

    Args:
        secret: Shared secret used to check the key's signature

    Returns:
        Tuple of (is_valid, license_data); ``license_data`` is the parsed
        key on success and an empty dict otherwise
    """
    raw_key = load_license_key()
    if not raw_key:
        return False, {}

    try:
        parsed = parse_license_key(raw_key)

        signature_ok = verify_license_signature(parsed, secret)
        if not signature_ok:
            raise LicenseInvalidError("Invalid license signature")

        expired = check_license_expiry(parsed['expiry_utc'])
        if expired:
            raise LicenseExpiredError("License has expired")

        logger.info(f"License verified for customer {parsed['customer_id']}, cores: {parsed['cores']}")
        return True, parsed

    except (LicenseError, ValueError) as exc:
        # ValueError covers a non-integer cores field raised while parsing.
        logger.error(f"License verification failed: {exc}")
        return False, {}
107
+
108
def get_licensed_cores() -> int:
    """
    Get the number of licensed cores, or 0 if no valid license.

    The signing secret is read from ``BIZTEAM_SERVER_SECRET``. When that
    variable is unset, a built-in development secret is used and a warning is
    logged, because a key signed with that well-known value would verify.

    Returns:
        The ``cores`` value of the verified license, or 0
    """
    server_secret = os.getenv('BIZTEAM_SERVER_SECRET')
    if server_secret is None:
        # NOTE(review): the fallback secret is shipped in the package, so it
        # offers no protection; consider failing closed in production.
        logger.warning(
            "BIZTEAM_SERVER_SECRET is not set; falling back to the built-in development secret"
        )
        server_secret = 'dev-secret-key'

    is_valid, license_data = verify_license(server_secret)
    return license_data.get('cores', 0) if is_valid else 0
smcp/logchain.py ADDED
@@ -0,0 +1,270 @@
1
+ """
2
+ Tamper-proof logging with SHA-256 chaining for audit trails.
3
+ """
4
+
5
+ import hashlib
6
+ import json
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+
12
class LogChain:
    """
    Append-only logger with SHA-256 chaining for tamper detection.

    The log is one JSON file holding a list of entries. Each entry stores the
    hash of the entry before it, so changing, inserting or removing an entry
    in the middle of the log breaks the chain from that point on. Removing
    entries from the end of the file leaves a shorter but still consistent
    chain, which this class cannot detect on its own.

    The whole file is read and rewritten on every append.
    """

    def __init__(self, log_path: str):
        """
        Initialize the log chain.

        Creates the parent directory and, if the file does not exist yet, a
        log containing only the genesis entry.

        Args:
            log_path: Path to the log file
        """
        self.log_path = Path(log_path)
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialize_log()

    def _initialize_log(self) -> None:
        """Write the genesis entry if the log file doesn't exist."""
        if self.log_path.exists():
            return

        genesis_entry = {
            "sequence": 0,
            "timestamp": time.time(),
            "event_type": "genesis",
            "data": "Log chain initialized",
            "previous_hash": "0" * 64,  # Genesis has no previous hash
        }
        genesis_entry["hash"] = self._calculate_hash(genesis_entry)
        self._save_log([genesis_entry])

    def _calculate_hash(self, entry: Dict[str, Any]) -> str:
        """
        Calculate the SHA-256 hash of a log entry, ignoring its "hash" field.

        The entry is first passed through a JSON round trip using the same
        ``default=str`` fallback as ``_save_log``. This makes the hashed form
        identical to what will later be read back from disk: tuples become
        lists, non-string dict keys become strings *before* keys are sorted,
        and values JSON cannot represent are replaced by ``str(value)``.
        Without this step such entries would fail verification after a
        reload, or could not be hashed at all.
        """
        payload = {k: v for k, v in entry.items() if k != "hash"}
        canonical = json.loads(json.dumps(payload, default=str))
        entry_json = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(entry_json.encode()).hexdigest()

    def _read_entries(self) -> List[Dict[str, Any]]:
        """
        Read the log strictly.

        Returns an empty list when the file is missing or blank.

        Raises:
            ValueError: If the file is not valid JSON, cannot be decoded as
                UTF-8, or does not contain a JSON list
        """
        try:
            text = self.log_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return []
        if not text.strip():
            return []
        entries = json.loads(text)  # json.JSONDecodeError is a ValueError
        if not isinstance(entries, list):
            raise ValueError("Log file does not contain a list of entries")
        return entries

    def _load_log(self) -> List[Dict[str, Any]]:
        """Load the current log entries; an unreadable log yields an empty list."""
        try:
            return self._read_entries()
        except ValueError:
            return []

    def _save_log(self, entries: List[Dict[str, Any]]) -> None:
        """Save log entries to disk, replacing the file's previous content."""
        with open(self.log_path, 'w', encoding="utf-8") as f:
            json.dump(entries, f, indent=2, default=str)

    def _get_last_hash(self) -> str:
        """Get the hash of the last log entry, or 64 zeros for an empty log."""
        entries = self._load_log()
        return entries[-1]["hash"] if entries else "0" * 64

    def append(self, event_type: str, data: Any) -> str:
        """
        Append a new entry to the log chain.

        Args:
            event_type: Type of event being logged
            data: Event data to log. Values that JSON cannot represent are
                stored as their ``str()`` form.

        Returns:
            Hash of the new entry
        """
        # NOTE: if the existing file is unreadable, _load_log returns [] and
        # the new entry starts a fresh chain that replaces the file.
        entries = self._load_log()
        previous_hash = entries[-1]["hash"] if entries else "0" * 64

        entry = {
            "sequence": len(entries),
            "timestamp": time.time(),
            "event_type": event_type,
            "data": data,
            "previous_hash": previous_hash,
        }
        entry["hash"] = self._calculate_hash(entry)

        entries.append(entry)
        self._save_log(entries)

        return entry["hash"]

    def verify_chain(self) -> Tuple[bool, Optional[int]]:
        """
        Verify the integrity of the log chain.

        A missing or blank file counts as an empty, valid log. A file that
        cannot be parsed as a JSON list is reported as invalid at position 0.

        Returns:
            Tuple of (is_valid, first_invalid_sequence)
            is_valid is False if tampering is detected
            first_invalid_sequence is the position of the first invalid
            entry, or None when the chain is valid
        """
        try:
            entries = self._read_entries()
        except ValueError:
            return False, 0

        expected_previous = "0" * 64  # what the genesis entry must point to
        for i, entry in enumerate(entries):
            # A tampered file may hold anything; treat malformed entries as
            # invalid instead of letting KeyError/TypeError escape.
            if not isinstance(entry, dict):
                return False, i

            if entry.get("previous_hash") != expected_previous:
                return False, i

            stored_hash = entry.get("hash")
            if stored_hash is None or self._calculate_hash(entry) != stored_hash:
                return False, i

            expected_previous = stored_hash

        return True, None

    def get_entries(self, start_sequence: int = 0, end_sequence: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get log entries in a range.

        Args:
            start_sequence: Starting sequence number (inclusive)
            end_sequence: Ending sequence number (inclusive, None for all)

        Returns:
            List of log entries in the specified range
        """
        entries = self._load_log()

        if end_sequence is None:
            end_sequence = len(entries) - 1

        return [entry for entry in entries
                if start_sequence <= entry["sequence"] <= end_sequence]

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the log chain.

        Returns:
            For an empty or unreadable log, a dict with ``total_entries`` (0)
            and ``chain_valid``. Otherwise also ``first_invalid_sequence``,
            the first and last entry timestamps and a per-event-type count.
        """
        entries = self._load_log()
        is_valid, invalid_seq = self.verify_chain()

        if not entries:
            # chain_valid is False here only when the file exists but is corrupt.
            return {"total_entries": 0, "chain_valid": is_valid}

        event_types: Dict[str, int] = {}
        for entry in entries:
            event_type = entry["event_type"]
            event_types[event_type] = event_types.get(event_type, 0) + 1

        return {
            "total_entries": len(entries),
            "chain_valid": is_valid,
            "first_invalid_sequence": invalid_seq,
            "first_entry_time": entries[0]["timestamp"],
            "last_entry_time": entries[-1]["timestamp"],
            "event_type_counts": event_types,
        }
174
+
175
+
176
# Process-wide LogChain instance, created lazily by get_logger().
_logger = None


def get_logger(cfg: Dict[str, Any]) -> Optional["LogChain"]:
    """
    Get or create the global logger if logging is enabled.

    Logging is enabled when ``cfg`` has a non-empty ``LOG_PATH``. The cached
    instance is reused while the path stays the same and replaced when a
    different ``LOG_PATH`` is requested, so events never go to a file other
    than the one named in the configuration passed in.

    Args:
        cfg: Configuration dictionary

    Returns:
        The LogChain for ``cfg["LOG_PATH"]``, or None if logging is disabled
    """
    global _logger

    log_path = cfg.get("LOG_PATH")
    if not log_path:
        return None

    if _logger is None or _logger.log_path != Path(log_path):
        _logger = LogChain(log_path)

    return _logger
192
+
193
+
194
def log_event(function_name: str, args: Tuple, kwargs: Dict[str, Any], result: Any, cfg: Dict[str, Any]) -> None:
    """
    Record a successful function call in the log chain.

    Does nothing when logging is disabled for ``cfg``. Keyword arguments
    whose names start with an underscore are left out of the record; the
    result itself is not stored, only its type name and the length of its
    string form.

    Args:
        function_name: Name of the executed function
        args: Function arguments
        kwargs: Function keyword arguments
        result: Function result
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if not chain:
        return

    # Underscore-prefixed keyword arguments are excluded from the record.
    visible_kwargs = {}
    for key, val in kwargs.items():
        if not key.startswith("_"):
            visible_kwargs[key] = val

    if result is None:
        result_size = 0
    else:
        result_size = len(str(result))

    chain.append("function_execution", {
        "function": function_name,
        "args": args,
        "kwargs": visible_kwargs,
        "result_type": type(result).__name__,
        "result_size": result_size,
        "success": True,
    })
222
+
223
+
224
def log_security_event(event_type: str, details: Dict[str, Any], cfg: Dict[str, Any]) -> None:
    """
    Record a security-related event in the log chain.

    Does nothing when logging is disabled for ``cfg``. The entry is written
    with the chain event type ``"security_event"``; the caller's
    ``event_type`` is stored inside the entry data.

    Args:
        event_type: Type of security event
        details: Event details
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if not chain:
        return

    payload = {"security_event_type": event_type, "details": details}
    chain.append("security_event", payload)
243
+
244
+
245
def log_error(function_name: str, error: Exception, args: Tuple, kwargs: Dict[str, Any], cfg: Dict[str, Any]) -> None:
    """
    Record a failed function call in the log chain.

    Does nothing when logging is disabled for ``cfg``. Keyword arguments
    whose names start with an underscore are left out of the record.

    Args:
        function_name: Name of the function that errored
        error: The exception that occurred
        args: Function arguments
        kwargs: Function keyword arguments
        cfg: Configuration dictionary
    """
    chain = get_logger(cfg)
    if not chain:
        return

    # Underscore-prefixed keyword arguments are excluded from the record.
    visible_kwargs = {}
    for key, val in kwargs.items():
        if not key.startswith("_"):
            visible_kwargs[key] = val

    chain.append("function_error", {
        "function": function_name,
        "error_type": type(error).__name__,
        "error_message": str(error),
        "args": args,
        "kwargs": visible_kwargs,
    })