prismiq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prismiq/sql_utils.py ADDED
@@ -0,0 +1,147 @@
1
+ """SQL utilities for building safe parameterized queries.
2
+
3
+ This module provides generic SQL validation and formatting utilities
4
+ that can be used across different database drivers (asyncpg, SQLAlchemy,
5
+ etc.).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+
11
def validate_identifier(identifier: str, field_name: str) -> None:
    """Reject SQL identifiers that could break out of double-quote quoting.

    Args:
        identifier: The identifier to check (table name, column name, alias).
        field_name: Human-readable label used in error messages
            (e.g., "table name").

    Raises:
        ValueError: If the identifier is empty, contains one of the forbidden
            sequences, or contains a character outside the allowed set.

    Security:
        Identifiers are expected to be wrapped in double quotes elsewhere;
        this check is a defense-in-depth layer that refuses quote characters,
        statement separators, comment markers and backslashes outright.

    Example:
        >>> validate_identifier("users", "table name")  # OK
        >>> validate_identifier("user'; DROP TABLE users--", "table name")  # Raises
    """
    if not identifier:
        raise ValueError(f"Invalid {field_name}: cannot be empty")

    # Sequences that could terminate quoting or start a comment. The order is
    # significant: the first sequence found is the one named in the error.
    forbidden_sequences = ('"', "'", ";", "--", "/*", "*/", "\\")
    offending = next((seq for seq in forbidden_sequences if seq in identifier), None)
    if offending is not None:
        raise ValueError(
            f"Invalid {field_name} '{identifier}': contains forbidden character '{offending}'"
        )

    # Whitelist applied on top of the blacklist above:
    # - alphanumerics and underscore: ordinary identifier characters
    # - dot: schema.table references
    # - space, parentheses, forward slash, hyphen, colon: seen in view column names
    # - non-breaking space (\xa0): shows up in data imported from Excel/Word
    extra_allowed = frozenset("_. ()/-:\xa0")
    for ch in identifier:
        if not (ch.isalnum() or ch in extra_allowed):
            raise ValueError(f"Invalid {field_name} '{identifier}': contains invalid characters")
49
+
50
+
51
def quote_identifier(identifier: str) -> str:
    """Quote a SQL identifier with double quotes for PostgreSQL.

    Any double quote embedded in the identifier is doubled, which is how
    PostgreSQL represents a literal double quote inside a quoted identifier.
    Without this, an identifier containing ``"`` would terminate the quoting
    early and the remainder would be interpreted as SQL.

    Args:
        identifier: Identifier to quote (table, column, alias).

    Returns:
        The identifier wrapped in double quotes, with embedded double quotes
        escaped, suitable for interpolation into a SQL statement.

    Example:
        >>> quote_identifier("users")
        '"users"'
        >>> quote_identifier("first_name")
        '"first_name"'
        >>> quote_identifier('a"b')
        '"a""b"'
    """
    # Identifiers that already passed validate-style checks contain no double
    # quotes, so for them the output is unchanged.
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'
67
+
68
+
69
def convert_java_date_format_to_postgres(java_format: str) -> str:
    """Convert Java SimpleDateFormat to PostgreSQL TO_CHAR format.

    The input is scanned once from left to right; at each position the
    longest known Java pattern is replaced by its PostgreSQL counterpart and
    every other character is copied through unchanged. Because replaced text
    is never re-scanned, no intermediate placeholders are needed and literal
    text in the input (including text that looks like ``<<0>>``) cannot be
    mistaken for one.

    Args:
        java_format: Date format using Java SimpleDateFormat patterns
            (e.g., "MMM-yyyy", "MM/dd/yyyy").

    Returns:
        PostgreSQL TO_CHAR format string (e.g., "Mon-YYYY", "MM/DD/YYYY").

    Example:
        >>> convert_java_date_format_to_postgres("MMM-yyyy")
        'Mon-YYYY'
        >>> convert_java_date_format_to_postgres("MM/dd/yyyy")
        'MM/DD/YYYY'
    """
    # Function-scope import: this module has no other use for ``re``.
    import re

    # Map Java SimpleDateFormat patterns to PostgreSQL patterns
    format_map = {
        "yyyy": "YYYY",  # 4-digit year
        "yy": "YY",  # 2-digit year
        "MMMM": "Month",  # Full month name
        "MMM": "Mon",  # Abbreviated month name
        "MM": "MM",  # 2-digit month
        "M": "MM",  # 1-2 digit month, emitted as the padded MM form
        "dd": "DD",  # 2-digit day
        "d": "DD",  # 1-2 digit day, emitted as the padded DD form
        "EEEE": "Day",  # Full day name
        "EEE": "Dy",  # Abbreviated day name
        "HH": "HH24",  # 24-hour
        "hh": "HH12",  # 12-hour
        "mm": "MI",  # Minutes
        "ss": "SS",  # Seconds
    }

    # Longest alternatives first so that e.g. "MMMM" wins over "MMM"/"MM"/"M"
    # when several patterns could match at the same position.
    alternatives = sorted(format_map, key=len, reverse=True)
    token_re = re.compile("|".join(re.escape(token) for token in alternatives))

    return token_re.sub(lambda match: format_map[match.group(0)], java_format)
121
+
122
+
123
# Whitelists used for validation. All are immutable frozensets of the exact
# (case-sensitive) tokens that are accepted.

# JOIN keywords (upper case).
ALLOWED_JOIN_TYPES = frozenset(("INNER", "LEFT", "RIGHT", "FULL"))

# Filter operator names.
ALLOWED_OPERATORS = frozenset(
    (
        # comparisons
        "eq",
        "ne",
        "gt",
        "gte",
        "lt",
        "lte",
        # membership
        "in",
        "in_",  # Alias for "in" (React types use in_)
        "in_or_null",
        "in_subquery",
        # pattern matching
        "like",
        "ilike",
        "not_like",
        "not_ilike",
    )
)

# Aggregation function names; "none" is listed alongside the real aggregates.
ALLOWED_AGGREGATIONS = frozenset(
    ("none", "sum", "avg", "count", "count_distinct", "min", "max")
)

# Date truncation units.
ALLOWED_DATE_TRUNCS = frozenset(
    ("year", "quarter", "month", "week", "day", "hour", "minute", "second")
)

# ORDER BY directions (upper case).
ALLOWED_ORDER_DIRECTIONS = frozenset(("ASC", "DESC"))
@@ -0,0 +1,219 @@
1
+ """SQL validation for custom SQL queries.
2
+
3
+ This module validates raw SQL queries to ensure they are safe to execute:
4
+ - Only SELECT statements allowed
5
+ - Only tables visible in the schema can be queried
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+
12
+ import sqlglot
13
+ import sqlglot.errors
14
+ from sqlglot import Expression, exp
15
+
16
+ from .types import DatabaseSchema, PrismiqError
17
+
18
+
19
class SQLValidationError(PrismiqError):
    """Raised when SQL validation fails.

    Attributes:
        errors: The individual validation messages supplied by the raiser;
            an empty list when none (or an empty list) was given.
    """

    def __init__(self, message: str, errors: list[str] | None = None) -> None:
        super().__init__(message)
        # Normalize a missing/empty argument to a fresh empty list.
        self.errors = errors if errors else []
25
+
26
+
27
@dataclass
class SQLValidationResult:
    """Result of SQL validation.

    Attributes:
        valid: Whether the SQL is valid.
        errors: List of validation errors (empty if valid).
        tables: List of tables referenced in the query.
        sanitized_sql: The SQL if valid, None otherwise.
    """

    # True when the query passed every check.
    valid: bool

    # Human-readable validation failures; empty for a valid query.
    errors: list[str]

    # Table names found in the query.
    tables: list[str]

    # SQL text to execute when valid; None when validation failed.
    sanitized_sql: str | None
42
+
43
+
44
class SQLValidator:
    """Validates raw SQL queries for safety and schema compliance.

    A query is accepted only when it is a single SELECT statement, references
    only tables present in the schema given at construction time, and contains
    no SELECT INTO or nested data-modifying statement.
    """

    ALLOWED_STATEMENT_TYPES = frozenset({"SELECT"})

    def __init__(self, schema: DatabaseSchema) -> None:
        """Initialize the validator.

        Args:
            schema: Database schema defining allowed tables.
        """
        self._schema = schema
        # Table names are compared case-insensitively.
        self._allowed_tables = frozenset(t.name.lower() for t in schema.tables)

    @staticmethod
    def _rejected(errors: list[str], tables: list[str] | None = None) -> SQLValidationResult:
        """Build a failed result carrying the given errors and tables."""
        return SQLValidationResult(
            valid=False,
            errors=errors,
            tables=tables if tables is not None else [],
            sanitized_sql=None,
        )

    def validate(self, sql: str) -> SQLValidationResult:
        """Validate a raw SQL query.

        Args:
            sql: The SQL query to validate.

        Returns:
            SQLValidationResult with validation status and details. On
            success ``sanitized_sql`` holds the statement re-rendered by
            sqlglot in the postgres dialect; on failure it is None.
        """
        # Parse the SQL. Tokenizer failures (e.g. an unterminated string
        # literal) raise TokenError rather than ParseError, so both are
        # reported as a parse error instead of escaping to the caller.
        try:
            statements = sqlglot.parse(sql, dialect="postgres")
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            return self._rejected([f"SQL parse error: {e}"])

        if not statements:
            return self._rejected(["No SQL statement provided"])

        # Only allow single statements
        if len(statements) > 1:
            return self._rejected(["Only single statements are allowed"])

        statement = statements[0]

        # Some parse results can be None for empty strings
        if statement is None:
            return self._rejected(["Empty SQL statement"])

        errors: list[str] = []

        # Check statement type
        statement_type = self._get_statement_type(statement)
        if statement_type not in self.ALLOWED_STATEMENT_TYPES:
            errors.append(
                f"Statement type '{statement_type}' is not allowed. Only SELECT statements are permitted."
            )

        # Extract and validate tables
        tables = self._extract_tables(statement)
        invalid_tables = self._check_table_access(tables)
        if invalid_tables:
            errors.append(
                f"Access denied to tables: {', '.join(sorted(invalid_tables))}. "
                f"Allowed tables: {', '.join(sorted(t.name for t in self._schema.tables))}"
            )

        # Check for dangerous operations even within SELECT
        errors.extend(self._check_dangerous_operations(statement))

        if errors:
            return self._rejected(errors, tables)

        # Generate sanitized SQL
        return SQLValidationResult(
            valid=True,
            errors=[],
            tables=tables,
            sanitized_sql=statement.sql(dialect="postgres"),
        )

    def _get_statement_type(self, statement: Expression) -> str:
        """Get the type of SQL statement as an upper-case keyword.

        Known statement classes map to their SQL keyword; an unparsed
        ``Command`` reports its leading keyword (or "COMMAND" when absent);
        anything else falls back to the expression class name.
        """
        known_types = (
            (exp.Select, "SELECT"),
            (exp.Insert, "INSERT"),
            (exp.Update, "UPDATE"),
            (exp.Delete, "DELETE"),
            (exp.Create, "CREATE"),
            (exp.Drop, "DROP"),
            (exp.Alter, "ALTER"),
        )
        for expression_class, keyword in known_types:
            if isinstance(statement, expression_class):
                return keyword
        if isinstance(statement, exp.Command):
            return statement.this.upper() if statement.this else "COMMAND"
        return type(statement).__name__.upper()

    def _extract_tables(self, statement: Expression) -> list[str]:
        """Extract all table names referenced in the statement.

        Returns:
            Sorted list of distinct, non-empty table names.
        """
        # NOTE(review): `table.name` appears to carry only the bare table
        # name, so a schema qualifier would not take part in the access
        # check, and CTE aliases referenced in FROM would be reported as
        # tables - confirm both are intended.
        found = {table.name for table in statement.find_all(exp.Table) if table.name}
        return sorted(found)

    def _check_table_access(self, tables: list[str]) -> set[str]:
        """Check if all tables are allowed.

        Returns set of invalid tables (original casing preserved).
        """
        return {table for table in tables if table.lower() not in self._allowed_tables}

    def _check_dangerous_operations(self, statement: Expression) -> list[str]:
        """Check for dangerous operations that could cause harm.

        Returns:
            One error message per problem found; empty when none.
        """
        errors: list[str] = []

        # Check for INTO clause (SELECT INTO creates new table)
        if statement.find(exp.Into):
            errors.append("SELECT INTO is not allowed")

        # Look for data-modifying statements anywhere below the root, not just
        # directly inside a Subquery: a writable CTE such as
        # ``WITH t AS (DELETE FROM t RETURNING *) SELECT * FROM t`` has a
        # SELECT root and would otherwise pass. A tuple is used instead of
        # ``A | B`` so isinstance-style matching also works before Python 3.10.
        modifying_types = (exp.Insert, exp.Update, exp.Delete)
        for node in statement.find_all(*modifying_types):
            if node is statement:
                # A modifying root is already reported by the statement-type check.
                continue
            if isinstance(node.parent, exp.Subquery):
                errors.append("Modification statements in subqueries are not allowed")
            else:
                errors.append("Modification statements in CTEs or nested queries are not allowed")

        return errors
206
+
207
+
208
def validate_sql(sql: str, schema: DatabaseSchema) -> SQLValidationResult:
    """Validate a SQL query against a schema in a single call.

    Builds a throwaway SQLValidator for ``schema`` and runs it on ``sql``.

    Args:
        sql: The SQL query to validate.
        schema: Database schema defining allowed tables.

    Returns:
        SQLValidationResult with validation status and details.
    """
    return SQLValidator(schema).validate(sql)