safeshield 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of safeshield might be problematic. Click here for more details.

@@ -0,0 +1,162 @@
1
+ import contextlib
2
+ from typing import Dict, Any, List, Optional, Iterator
3
+ import mysql.connector
4
+ import psycopg2
5
+
6
class DatabaseManager:
    """Connection and query helper for MySQL, PostgreSQL and SQLite.

    ``config`` is either a mapping with ``type``, ``host``, ``port``, ``user``
    (or ``username``), ``password``, ``database`` and ``options`` keys, or a
    mapping holding a SQLAlchemy-style ``url`` (plus optional ``options``).
    The connection is opened lazily by the first query and reused until
    :meth:`disconnect` is called.
    """

    # Alternative backend spellings mapped to the names the driver dispatch uses.
    _TYPE_ALIASES = {'postgres': 'postgresql', 'sqlite3': 'sqlite', 'mariadb': 'mysql'}

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.connection = None
        self._connection_params = self._normalize_config()

    @classmethod
    def _canonical_type(cls, db_type: Optional[str]) -> str:
        """Lower-case *db_type* (default ``mysql``) and resolve known aliases."""
        name = (db_type or 'mysql').lower()
        return cls._TYPE_ALIASES.get(name, name)

    def _normalize_config(self) -> Dict[str, Any]:
        """Normalize the supported config formats into one connection-params dict."""
        if 'url' in self.config:
            params = self._parse_url(self.config['url'])
            # Driver options cannot be expressed in the URL; accept them alongside it.
            params['options'] = dict(self.config.get('options') or {})
            return params

        db_type = self._canonical_type(self.config.get('type', 'mysql'))
        return {
            'type': db_type,
            'host': self.config.get('host', 'localhost'),
            'port': self.config.get('port', 3306 if db_type == 'mysql' else 5432),
            'user': self.config.get('user') or self.config.get('username', 'root'),
            'password': self.config.get('password', ''),
            'database': self.config.get('database', ''),
            'options': self.config.get('options', {}),
        }

    def _parse_url(self, url: str) -> Dict[str, Any]:
        """Parse a SQLAlchemy-style URL, e.g. ``postgresql+psycopg2://u:p@host:5432/db``.

        The ``+driver`` suffix is ignored. User name and password are
        percent-decoded, because SQLAlchemy URLs escape special characters in them.
        """
        from urllib.parse import unquote, urlparse

        parsed = urlparse(url)
        db_type = self._canonical_type(parsed.scheme.split('+')[0])
        port = parsed.port
        if port is None and db_type == 'mysql':
            # Give the MySQL driver its standard port rather than None.
            port = 3306
        return {
            'type': db_type,
            'host': parsed.hostname,
            'port': port,
            'user': unquote(parsed.username) if parsed.username is not None else None,
            'password': unquote(parsed.password) if parsed.password is not None else None,
            'database': parsed.path[1:] if parsed.path else '',
            'options': {},
        }

    @contextlib.contextmanager
    def get_cursor(self) -> Iterator[Any]:
        """Yield a cursor whose rows are dicts keyed by column name; always closed afterwards."""
        self.connect()
        cursor = None
        try:
            cursor = self._dict_cursor()
            yield cursor
        finally:
            if cursor is not None:
                cursor.close()

    def _dict_cursor(self) -> Any:
        """Create a dict-row cursor; each driver exposes this differently."""
        db_type = self._connection_params['type']
        if db_type == 'postgresql':
            import psycopg2.extras
            return self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        if db_type == 'sqlite':
            cursor = self.connection.cursor()
            cursor.row_factory = self._sqlite_dict_row
            return cursor
        return self.connection.cursor(dictionary=True)  # mysql.connector

    @staticmethod
    def _sqlite_dict_row(cursor: Any, row: tuple) -> Dict[str, Any]:
        """sqlite3 row factory mapping column names to values."""
        return {column[0]: value for column, value in zip(cursor.description, row)}

    def _is_connected(self) -> bool:
        """Return True if ``self.connection`` is open, using each driver's own state API."""
        conn = self.connection
        if conn is None:
            return False
        if hasattr(conn, 'is_connected'):  # mysql.connector
            try:
                return bool(conn.is_connected())
            except Exception:
                return False
        if hasattr(conn, 'closed'):  # psycopg2: 0 while open, non-zero once closed
            return not conn.closed
        # sqlite3 has no liveness API; disconnect() drops the handle, so a
        # non-None connection is treated as open.
        return True

    def connect(self):
        """Open the database connection unless an open one already exists.

        Raises:
            ConnectionError: the configured type is unsupported, or the driver
                failed to connect (the driver error is chained as the cause).
        """
        if self._is_connected():
            return

        params = self._connection_params
        db_type = params['type']
        options = params.get('options') or {}
        if db_type not in ('mysql', 'postgresql', 'sqlite'):
            raise ConnectionError(f"Unsupported database type: {db_type!r}")

        try:
            if db_type == 'mysql':
                self.connection = mysql.connector.connect(
                    host=params['host'],
                    user=params['user'],
                    password=params['password'],
                    database=params['database'],
                    port=params['port'],
                    **options
                )
            elif db_type == 'postgresql':
                self.connection = psycopg2.connect(
                    host=params['host'],
                    user=params['user'],
                    password=params['password'],
                    dbname=params['database'],
                    port=params['port'],
                    **options
                )
            else:
                import sqlite3
                self.connection = sqlite3.connect(
                    params['database'] or ':memory:',
                    **options
                )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to database: {str(e)}") from e

    def disconnect(self):
        """Close the connection (best-effort) and forget it."""
        connection, self.connection = self.connection, None
        if connection is None:
            return
        try:
            connection.close()
        except Exception:
            # Best-effort cleanup: the handle has already been discarded.
            pass

    def query(self, sql: str, params: tuple = ()) -> List[Dict[str, Any]]:
        """Execute *sql* with bound *params* and return all rows as dicts.

        Opens (or re-opens) the connection first if it is not currently open.
        """
        with self.get_cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    @property
    def _placeholder(self) -> str:
        """DB-API parameter marker for the configured driver (sqlite3 uses qmark style)."""
        return '?' if self._connection_params['type'] == 'sqlite' else '%s'

    @staticmethod
    def _check_identifier(name: str) -> str:
        """Return *name* if it is a plain (optionally schema-qualified) SQL identifier.

        Table and column names cannot be bound as query parameters, so they are
        restricted to ``[A-Za-z0-9_$]`` parts to prevent SQL injection.

        Raises:
            ValueError: *name* contains anything else.
        """
        import re
        if not isinstance(name, str) or not re.fullmatch(
                r'[A-Za-z0-9_$]+(?:\.[A-Za-z0-9_$]+)?', name):
            raise ValueError(f"Invalid SQL identifier: {name!r}")
        return name

    def exists(self, table: str, column: str, value: Any) -> bool:
        """Return True if at least one row of *table* has *column* equal to *value*."""
        table = self._check_identifier(table)
        column = self._check_identifier(column)
        sql = f"SELECT 1 FROM {table} WHERE {column} = {self._placeholder} LIMIT 1"
        return bool(self.query(sql, (value,)))

    def is_unique(self, table: str, column: str, value: Any,
                  ignore_id: Optional[int] = None, id_column: str = 'id') -> bool:
        """Return True if no row of *table* has *column* equal to *value*.

        When *ignore_id* is given, the row whose *id_column* equals it is
        excluded (e.g. the record currently being updated).
        """
        table = self._check_identifier(table)
        column = self._check_identifier(column)
        marker = self._placeholder
        where_clause = f"{column} = {marker}"
        params = [value]

        if ignore_id is not None:
            where_clause += f" AND {self._check_identifier(id_column)} != {marker}"
            params.append(ignore_id)

        sql = f"""
            SELECT NOT EXISTS (
                SELECT 1 FROM {table}
                WHERE {where_clause}
                LIMIT 1
            ) AS is_unique
        """

        result = self.query(sql, tuple(params))
        # Drivers return NOT EXISTS as 0/1; normalize to bool.
        return bool(result[0]['is_unique'])

    def table_exists(self, table_name: str) -> bool:
        """Check whether *table_name* exists; returns False on any error (best-effort).

        MySQL looks in the configured database, PostgreSQL in the ``public``
        schema, and SQLite in ``sqlite_master``.
        """
        try:
            if self._connection_params['type'] == 'mysql':
                sql = """
                    SELECT COUNT(*) AS count
                    FROM information_schema.tables
                    WHERE table_schema = %s AND table_name = %s
                """
                result = self.query(sql, (self._connection_params['database'], table_name))
            elif self._connection_params['type'] == 'postgresql':
                sql = """
                    SELECT COUNT(*) AS count
                    FROM information_schema.tables
                    WHERE table_schema = 'public' AND table_name = %s
                """
                result = self.query(sql, (table_name,))
            else:  # SQLite
                sql = """
                    SELECT COUNT(*) AS count
                    FROM sqlite_master
                    WHERE type='table' AND name=?
                """
                result = self.query(sql, (table_name,))

            return result[0]['count'] > 0
        except Exception:
            return False
@@ -0,0 +1,10 @@
1
+ from typing import Dict, List, Tuple, Optional
2
+
3
class ValidationException(Exception):
    """Raised when input data fails validation.

    The ``errors`` mapping is kept on the instance and is also the single
    positional argument of the exception.
    """

    def __init__(self, errors: dict):
        super().__init__(errors)
        self.errors = errors
7
+
8
class RuleNotFoundException(ValueError):
    """Raised when a rule name has no registered implementation."""

    def __init__(self, rule_name: str):
        text = "Validation rule '{}' is not registered".format(rule_name)
        super().__init__(text)
validator/factory.py ADDED
@@ -0,0 +1,26 @@
1
+ # factory.py
2
+ from typing import Type, Dict, List
3
+ from .rules import all_rules
4
+ from .rules.base import ValidationRule
5
+
6
class RuleFactory:
    """Instantiate validation rules by their registered name.

    NOTE: ``_rules`` is the same dict object as ``rules.all_rules`` (it is not
    copied), so :meth:`register_rule` also adds the rule to that registry.
    """

    _rules: Dict[str, Type[ValidationRule]] = all_rules

    @classmethod
    def create_rule(cls, rule_name: str) -> ValidationRule:
        """Return a new, parameterless instance of the rule registered as *rule_name*.

        Raises:
            ValueError: no rule is registered under *rule_name*.
        """
        try:
            rule_class = cls._rules[rule_name]
        except KeyError:
            raise ValueError(f"Unknown validation rule: {rule_name}") from None
        # Instantiate outside the ``try`` so a KeyError raised by the rule's
        # own constructor is not misreported as an unknown rule.
        return rule_class()

    @classmethod
    def register_rule(cls, name: str, rule_class: Type[ValidationRule]):
        """Register (or replace) *rule_class* under *name*."""
        cls._rules[name] = rule_class

    @classmethod
    def has_rule(cls, rule_name: str) -> bool:
        """Return True if a rule is registered under *rule_name*."""
        return rule_name in cls._rules

    @classmethod
    def get_rule_names(cls) -> List[str]:
        """Return the names of all registered rules."""
        return list(cls._rules.keys())
@@ -0,0 +1,27 @@
1
+ import inspect
2
+ from .base import ValidationRule
3
+ from ..utils.string import pascal_to_snake
4
+ from . import array, basic, comparison, date, type, files, format, conditional, string, utilities
5
+
6
def _collect_rules():
    """Build a ``{rule_name: rule_class}`` registry from the rule submodules.

    Every concrete ``ValidationRule`` subclass found as an attribute of one of
    the submodules is registered under its ``rule_name``. Abstract helper
    subclasses are skipped because they cannot be instantiated. A class
    re-exported by several modules maps to the same key, so repeats are harmless.
    """
    modules = [array, basic, comparison, date, type, files, format, conditional, string, utilities]
    rules = {}

    for module in modules:
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if (issubclass(obj, ValidationRule)
                    and obj is not ValidationRule
                    and not inspect.isabstract(obj)):
                rules[obj.rule_name] = obj
    return rules
17
+
18
all_rules = _collect_rules()


def _export_name(rule_class):
    """Public alias for a rule class: its name with a trailing ``Rule`` removed."""
    name = rule_class.__name__
    return name[:-len('Rule')] if name.endswith('Rule') else name


def _export_rules(namespace, rules):
    """Bind each rule in *namespace* under its class alias and its rule name.

    Done inside a function so no loop variables leak into the package namespace.
    NOTE(review): if a rule_name equals a submodule name (e.g. ``array``), it
    replaces that submodule attribute here — depends on pascal_to_snake; confirm.
    """
    for rule_name, rule_class in rules.items():
        namespace[_export_name(rule_class)] = rule_class
        namespace[rule_name] = rule_class


_export_rules(globals(), all_rules)

__all__ = ['ValidationRule'] + \
    list(all_rules.keys()) + \
    [_export_name(rule_class) for rule_class in all_rules.values()]
@@ -0,0 +1,77 @@
1
+ from .base import ValidationRule
2
+ from typing import Any, Dict, List, Optional, Set, Union, Tuple, Type
3
+
4
class ArrayRule(ValidationRule):
    """Without params: value must be a list, tuple or set.

    With params (key names or Enum classes): value must be a dict containing
    every listed key. Keys found missing are kept for :meth:`message`.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            self._missing_keys = []
            return isinstance(value, (list, tuple, set))

        # Expand Enum classes / comma-separated strings into plain string keys.
        required_keys = self._parse_option_values(self.rule_name, params)
        if not isinstance(value, dict):
            # Not a mapping, so every required key counts as missing.
            self._missing_keys = required_keys
            return False

        self._missing_keys = [key for key in required_keys if key not in value]
        return not self._missing_keys

    def message(self, field: str, params: List[str]) -> str:
        if params:
            missing = getattr(self, '_missing_keys', None)
            if missing is None:
                # validate() was not run on this instance; report every required key.
                missing = self._parse_option_values(self.rule_name, params)
            return f"The {field} must contain the keys: {', '.join(str(key) for key in missing)}."
        return f"The {field} must be an array."
24
+
25
class ContainsRule(ValidationRule):
    """Value must contain ``params[0]``.

    For a string this is a substring check. For a list/tuple/set it checks
    membership, and for a dict it checks key membership.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            return False

        search_value = params[0]

        # String contains substring
        if isinstance(value, str):
            return search_value in value

        # Collection contains element / dict contains key. Rule parameters are
        # usually strings, so also compare members by their string form:
        # [1, 2] contains "1".
        if isinstance(value, (list, tuple, set, dict)):
            return any(item == search_value or str(item) == search_value for item in value)

        return False

    def message(self, field: str, params: List[str]) -> str:
        expected = params[0] if params else ''
        return f"The :name must contain {expected}"
48
+
49
class DistinctRule(ValidationRule):
    """Value must be a list, tuple or set with no two equal members."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not isinstance(value, (list, tuple, set)):
            return False

        try:
            return len(value) == len(set(value))
        except TypeError:
            # Unhashable members (e.g. dicts): fall back to equality comparison.
            seen = []
            for item in value:
                if item in seen:
                    return False
                seen.append(item)
            return True

    def message(self, field: str, params: List[str]) -> str:
        return "The :name must contain unique values"
58
+
59
class InArrayRule(ValidationRule):
    """Passes when ``str(value)`` equals one of the rule parameters."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if params:
            return str(value) in params
        return False

    def message(self, field: str, params: List[str]) -> str:
        allowed = ', '.join(params)
        return f"The :name must be one of: {allowed}"
68
+
69
class InArrayKeysRule(ValidationRule):
    """Passes when the value is a dict holding at least one of the parameter keys."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if params and isinstance(value, dict):
            return any(map(value.__contains__, params))
        return False

    def message(self, field: str, params: List[str]) -> str:
        keys = ', '.join(params)
        return f"The :name must contain at least one of these keys: {keys}"
@@ -0,0 +1,84 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Set, Union, Tuple, Type
3
+ from enum import Enum
4
+ from ..utils.string import pascal_to_snake
5
+ import inspect
6
+
7
class ValidationRule(ABC):
    """Abstract base class for all validation rules.

    Concrete rules implement :meth:`validate` and :meth:`message`. Each
    subclass gets a ``rule_name`` attribute computed from its class name by
    ``pascal_to_snake`` when the subclass is created.
    """

    def __init__(self, *params: str):
        self._validator: Optional['Validator'] = None
        self._params: List[str] = list(params)

    def __init_subclass__(cls, **kwargs: Any):
        # Cooperate with any other __init_subclass__ hooks in the MRO.
        super().__init_subclass__(**kwargs)
        # Assigned on every subclass, replacing any inherited/class-level value.
        cls.rule_name = pascal_to_snake(cls.__name__)

    @property
    def params(self) -> List[str]:
        """Parameters given to the constructor (or assigned later)."""
        return self._params

    @params.setter
    def params(self, value: List[str]) -> None:
        self._params = value

    def set_validator(self, validator: 'Validator') -> None:
        """Set the validator instance this rule belongs to."""
        self._validator = validator

    def set_field_exists(self, exists: bool):
        """Store the flag returned by :attr:`field_exists` (presumably: field present in input)."""
        self._field_exists = exists

    @property
    def validator(self) -> 'Validator':
        """Get the validator instance; raises RuntimeError if it was never set."""
        if self._validator is None:
            raise RuntimeError("Validator not set for this rule!")
        return self._validator

    @property
    def field_exists(self):
        # Raises AttributeError until set_field_exists() has been called.
        return self._field_exists

    def get_field_value(self, field_name, default=''):
        """Return the validator's value for *field_name*, converted to ``str``."""
        return str(self.validator.data.get(field_name, default))

    @staticmethod
    def is_empty(value):
        """True for None, '', [] and {} (0 and False are not empty)."""
        return value in (None, '', [], {})

    @abstractmethod
    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        """Validate a field's value."""
        pass

    @abstractmethod
    def message(self, field: str, params: List[str]) -> str:
        """Generate an error message if validation fails."""
        pass

    @property
    @abstractmethod
    def rule_name(self) -> str:
        """Return the name of the rule for error messages."""
        pass

    def _parse_option_values(self, field: str, params: List[str]) -> List[Any]:
        """Expand *params* into a de-duplicated list of allowed string values.

        Enum classes contribute the ``str`` of each member's value. Any other
        param is converted with ``str`` and split on commas. Values are
        stripped, empty ones dropped, and first-seen order is preserved.

        Raises:
            ValueError: *params* is empty.
        """
        if not params:
            raise ValueError(
                f"{self.rule_name} rule requires parameters. "
                f"Use '({self.rule_name}, EnumClass)' or '{self.rule_name}:val1,val2'"
            )

        values: List[str] = []
        for param in params:
            if inspect.isclass(param) and issubclass(param, Enum):
                values.extend(str(member.value) for member in param)
            else:
                values.extend(str(param).split(','))

        # dict.fromkeys de-duplicates while keeping order (a set would not).
        return list(dict.fromkeys(v.strip() for v in values if v.strip()))
@@ -0,0 +1,41 @@
1
+ from .base import ValidationRule
2
+ from typing import Any, Dict, List, Optional, Set, Union, Tuple, Type
3
+
4
+ # =============================================
5
+ # BASIC VALIDATION RULES
6
+ # =============================================
7
+
8
class RequiredRule(ValidationRule):
    """Fails when the value is empty according to ``is_empty``."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        is_blank = self.is_empty(value)
        return not is_blank

    def message(self, field: str, params: List[str]) -> str:
        return "The :name field is required."
14
+
15
class NullableRule(ValidationRule):
    """Marker rule whose check always passes.

    Presumably the validator interprets its presence as "field may be null".
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        # Nothing to check on its own.
        return True

    def message(self, field: str, params: List[str]) -> str:
        return "The :name may be null."
21
+
22
class FilledRule(ValidationRule):
    """Fails only for '' and None (empty lists/dicts still pass)."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        blank_values = ('', None)
        return value not in blank_values

    def message(self, field: str, params: List[str]) -> str:
        return "The :name field must have a value."
28
+
29
class PresentRule(ValidationRule):
    """Passes when the field name is a key of the validator's input data."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        data = self.validator.data
        return field in data

    def message(self, field: str, params: List[str]) -> str:
        return "The :name field must be present."
35
+
36
class SometimesRule(ValidationRule):
    """Marker rule: always passes and has an empty message."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        # Presumably the validator decides when the field's other rules apply.
        return True

    def message(self, field: str, params: List[str]) -> str:
        return str()
@@ -0,0 +1,240 @@
1
+ from .base import ValidationRule
2
+ from typing import Any, Dict, List, Optional, Set, Union, Tuple, Type
3
+
4
+ # =============================================
5
+ # COMPARISON VALIDATION RULES
6
+ # =============================================
7
+
8
class MinRule(ValidationRule):
    """Lower bound ``params[0]``.

    Numbers are compared by value, strings and collections by length, and
    uploaded files / streams by size in bytes.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            return False

        try:
            min_val = float(params[0])
        except (ValueError, TypeError):
            return False

        # Numbers (bool included, as an int subclass) compare by value. Checked
        # before files so numeric types exposing ``.size`` are not treated as files.
        if isinstance(value, (int, float)):
            return value >= min_val
        # Numeric-looking strings are measured by length, not parsed.
        if isinstance(value, (str, list, dict, set, tuple)):
            return len(value) >= min_val

        size = self._file_size(value)
        if size is None:
            return False
        try:
            return size >= min_val
        except TypeError:
            return False

    @staticmethod
    def _file_size(value: Any) -> Optional[float]:
        """Return the size in bytes of a file-like *value*, or None if it is not file-like.

        ``content_length`` is trusted only when non-zero, because Werkzeug
        documents it as usually unavailable. Otherwise a non-callable ``size``
        is used. Failing that, a seekable stream is measured by seeking to the
        end and back to its original position.
        """
        if value is None or isinstance(value, (int, float, str, list, dict, set, tuple)):
            return None
        content_length = getattr(value, 'content_length', None)
        if content_length:
            return content_length
        size = getattr(value, 'size', None)
        if size is not None and not callable(size):
            return size
        if hasattr(value, 'seek') and hasattr(value, 'tell'):
            try:
                position = value.tell()
                value.seek(0, 2)  # seek to end
                end = value.tell()
                value.seek(position)
                return end
            except (AttributeError, OSError, ValueError):
                pass
        return content_length

    def message(self, field: str, params: List[str]) -> str:
        return f"The :name must be at least {params[0]}."
31
+
32
class MaxRule(ValidationRule):
    """Upper bound ``params[0]``.

    Numbers are compared by value, strings and collections by length, and
    uploaded files / streams by size in bytes.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            return False

        try:
            max_val = float(params[0])
        except (ValueError, TypeError):
            return False

        # Numbers (bool included, as an int subclass) compare by value. Checked
        # before files so numeric types exposing ``.size`` are not treated as files.
        if isinstance(value, (int, float)):
            return value <= max_val
        # Numeric-looking strings are measured by length, not parsed.
        if isinstance(value, (str, list, dict, set, tuple)):
            return len(value) <= max_val

        size = self._file_size(value)
        if size is None:
            return False
        try:
            return size <= max_val
        except TypeError:
            return False

    @staticmethod
    def _file_size(value: Any) -> Optional[float]:
        """Return the size in bytes of a file-like *value*, or None if it is not file-like.

        ``content_length`` is trusted only when non-zero, because Werkzeug
        documents it as usually unavailable. Otherwise a non-callable ``size``
        is used. Failing that, a seekable stream is measured by seeking to the
        end and back to its original position.
        """
        if value is None or isinstance(value, (int, float, str, list, dict, set, tuple)):
            return None
        content_length = getattr(value, 'content_length', None)
        if content_length:
            return content_length
        size = getattr(value, 'size', None)
        if size is not None and not callable(size):
            return size
        if hasattr(value, 'seek') and hasattr(value, 'tell'):
            try:
                position = value.tell()
                value.seek(0, 2)  # seek to end
                end = value.tell()
                value.seek(position)
                return end
            except (AttributeError, OSError, ValueError):
                pass
        return content_length

    def message(self, field: str, params: List[str]) -> str:
        # Inspect the raw value: get_field_value() returns str(), which hides file objects.
        value = self.validator.data.get(field)
        if self._file_size(value) is not None:
            return f"File {field} must not exceed {params[0]} bytes"
        return f"The {field} must not exceed {params[0]}"
95
+
96
class BetweenRule(ValidationRule):
    """Inclusive range ``params[0]``..``params[1]``.

    Numbers are compared by value, strings and collections by length, and
    uploaded files / streams by size in bytes.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params or len(params) != 2:
            return False

        try:
            min_val = float(params[0])
            max_val = float(params[1])
        except (ValueError, TypeError):
            return False

        # Numbers (bool included, as an int subclass) compare by value. Checked
        # before files so numeric types exposing ``.size`` are not treated as files.
        if isinstance(value, (int, float)):
            return min_val <= value <= max_val
        # Numeric-looking strings are measured by length, not parsed.
        if isinstance(value, (str, list, dict, set, tuple)):
            return min_val <= len(value) <= max_val

        size = self._file_size(value)
        if size is None:
            return False
        try:
            return min_val <= size <= max_val
        except TypeError:
            return False

    @staticmethod
    def _file_size(value: Any) -> Optional[float]:
        """Return the size in bytes of a file-like *value*, or None if it is not file-like.

        ``content_length`` is trusted only when non-zero, because Werkzeug
        documents it as usually unavailable. Otherwise a non-callable ``size``
        is used. Failing that, a seekable stream is measured by seeking to the
        end and back to its original position.
        """
        if value is None or isinstance(value, (int, float, str, list, dict, set, tuple)):
            return None
        content_length = getattr(value, 'content_length', None)
        if content_length:
            return content_length
        size = getattr(value, 'size', None)
        if size is not None and not callable(size):
            return size
        if hasattr(value, 'seek') and hasattr(value, 'tell'):
            try:
                position = value.tell()
                value.seek(0, 2)  # seek to end
                end = value.tell()
                value.seek(position)
                return end
            except (AttributeError, OSError, ValueError):
                pass
        return content_length

    def message(self, field: str, params: List[str]) -> str:
        # Inspect the raw value: get_field_value() returns str(), which hides file objects.
        value = self.validator.data.get(field)
        if self._file_size(value) is not None:
            return f"File {field} must be between {params[0]} and {params[1]} bytes"
        return f"The {field} must be between {params[0]} and {params[1]}"
141
+
142
class SizeRule(ValidationRule):
    """Exact size ``params[0]``.

    Numbers are compared by value, strings and collections by length, and
    uploaded files / streams by size in bytes.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            return False

        try:
            target_size = float(params[0])
        except (ValueError, TypeError):
            return False

        # Numbers (bool included, as an int subclass) compare by value. Checked
        # before files so numeric types exposing ``.size`` are not treated as files.
        if isinstance(value, (int, float)):
            return value == target_size
        # Numeric-looking strings are measured by length, not parsed.
        if isinstance(value, (str, list, dict, set, tuple)):
            return len(value) == target_size

        size = self._file_size(value)
        return size is not None and size == target_size

    @staticmethod
    def _file_size(value: Any) -> Optional[float]:
        """Return the size in bytes of a file-like *value*, or None if it is not file-like.

        ``content_length`` is trusted only when non-zero, because Werkzeug
        documents it as usually unavailable. Otherwise a non-callable ``size``
        is used. Failing that, a seekable stream is measured by seeking to the
        end and back to its original position.
        """
        if value is None or isinstance(value, (int, float, str, list, dict, set, tuple)):
            return None
        content_length = getattr(value, 'content_length', None)
        if content_length:
            return content_length
        size = getattr(value, 'size', None)
        if size is not None and not callable(size):
            return size
        if hasattr(value, 'seek') and hasattr(value, 'tell'):
            try:
                position = value.tell()
                value.seek(0, 2)  # seek to end
                end = value.tell()
                value.seek(position)
                return end
            except (AttributeError, OSError, ValueError):
                pass
        return content_length

    def message(self, field: str, params: List[str]) -> str:
        # Inspect the raw value: get_field_value() returns str(), which hides file objects.
        value = self.validator.data.get(field)
        if self._file_size(value) is not None:
            return f"File {field} must be exactly {params[0]} bytes"
        return f"The {field} must be exactly {params[0]}"
193
+
194
class DigitsRule(ValidationRule):
    """Value must consist of exactly ``params[0]`` ASCII digits (0-9).

    Non-bool integers are checked through their decimal string form.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params:
            return False
        if isinstance(value, int) and not isinstance(value, bool):
            value = str(value)
        if not isinstance(value, str):
            return False

        try:
            digits = int(params[0])
        except (ValueError, TypeError):
            return False

        # str.isdigit() also accepts characters such as '²'; restrict to 0-9.
        is_ascii_number = bool(value) and all('0' <= ch <= '9' for ch in value)
        return is_ascii_number and len(value) == digits

    def message(self, field: str, params: List[str]) -> str:
        return f"The :name must be {params[0]} digits."
208
+
209
class DigitsBetweenRule(ValidationRule):
    """Value must be ASCII digits (0-9) whose count is within ``params[0]``..``params[1]``.

    Non-bool integers are checked through their decimal string form.
    """

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        if not params or len(params) < 2:
            return False
        if isinstance(value, int) and not isinstance(value, bool):
            value = str(value)
        if not isinstance(value, str):
            return False

        try:
            min_digits = int(params[0])
            max_digits = int(params[1])
        except (ValueError, TypeError):
            return False

        # str.isdigit() also accepts characters such as '²'; restrict to 0-9.
        is_ascii_number = bool(value) and all('0' <= ch <= '9' for ch in value)
        return is_ascii_number and min_digits <= len(value) <= max_digits

    def message(self, field: str, params: List[str]) -> str:
        return f"The :name must be between {params[0]} and {params[1]} digits."
224
+
225
class MultipleOfRule(ValidationRule):
    """Numeric value must be an exact multiple of ``params[0]`` (non-zero)."""

    def validate(self, field: str, value: Any, params: List[str]) -> bool:
        from decimal import Decimal, InvalidOperation

        if not params:
            return False

        try:
            divisor = Decimal(str(params[0]))
            # bool is an int subclass, but str(True) is not a decimal literal.
            number = Decimal(int(value)) if isinstance(value, bool) else Decimal(str(value))
        except (InvalidOperation, ValueError, TypeError):
            return False

        # Finite checks first: comparing a signalling NaN with == would raise.
        if not divisor.is_finite() or not number.is_finite() or divisor == 0:
            return False

        try:
            # Decimal is exact for decimal literals, so 0.3 is a multiple of 0.1
            # (with binary floats, 0.3 % 0.1 == 0.09999999999999998).
            return number % divisor == 0
        except InvalidOperation:
            # Quotient exceeds the decimal context's precision; fall back to floats.
            try:
                return float(number) % float(divisor) == 0
            except ZeroDivisionError:
                return False

    def message(self, field: str, params: List[str]) -> str:
        return f"The :name must be a multiple of {params[0]}."