sqlscope 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlscope/__init__.py +4 -0
- sqlscope/catalog/__init__.py +12 -0
- sqlscope/catalog/builder/__init__.py +8 -0
- sqlscope/catalog/builder/postgres.py +207 -0
- sqlscope/catalog/builder/sql.py +219 -0
- sqlscope/catalog/catalog.py +147 -0
- sqlscope/catalog/column.py +68 -0
- sqlscope/catalog/constraint.py +83 -0
- sqlscope/catalog/schema.py +60 -0
- sqlscope/catalog/table.py +112 -0
- sqlscope/query/__init__.py +5 -0
- sqlscope/query/extractors.py +118 -0
- sqlscope/query/query.py +191 -0
- sqlscope/query/set_operations/__init__.py +181 -0
- sqlscope/query/set_operations/binary_set_operation.py +162 -0
- sqlscope/query/set_operations/select.py +664 -0
- sqlscope/query/set_operations/set_operation.py +59 -0
- sqlscope/query/smt.py +334 -0
- sqlscope/query/tokenized_sql.py +70 -0
- sqlscope/query/typechecking/__init__.py +21 -0
- sqlscope/query/typechecking/base.py +9 -0
- sqlscope/query/typechecking/binary_ops.py +57 -0
- sqlscope/query/typechecking/functions.py +81 -0
- sqlscope/query/typechecking/predicates.py +123 -0
- sqlscope/query/typechecking/primitives.py +80 -0
- sqlscope/query/typechecking/queries.py +35 -0
- sqlscope/query/typechecking/types.py +59 -0
- sqlscope/query/typechecking/unary_ops.py +51 -0
- sqlscope/query/typechecking/util.py +51 -0
- sqlscope/util/__init__.py +18 -0
- sqlscope/util/ast/__init__.py +55 -0
- sqlscope/util/ast/column.py +55 -0
- sqlscope/util/ast/function.py +10 -0
- sqlscope/util/ast/subquery.py +23 -0
- sqlscope/util/ast/table.py +36 -0
- sqlscope/util/sql.py +27 -0
- sqlscope/util/tokens.py +17 -0
- sqlscope-1.0.0.dist-info/METADATA +52 -0
- sqlscope-1.0.0.dist-info/RECORD +41 -0
- sqlscope-1.0.0.dist-info/WHEEL +4 -0
- sqlscope-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import sqlglot
|
|
2
|
+
from sqlglot import exp
|
|
3
|
+
from ...catalog import Table, Column
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from .select import Select
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SetOperation(ABC):
    '''
    Common interface for every node of a SQL set-operation tree
    (i.e., SELECT, UNION, INTERSECT, EXCEPT).
    '''

    def __init__(self, sql: str, parent_query: 'Select | None' = None) -> None:
        self.sql = sql
        '''SQL text of this operation.'''

        self.parent_query = parent_query
        '''Enclosing Select when this operation is a subquery, otherwise None.'''

    @property
    @abstractmethod
    def output(self) -> Table:
        '''Table describing the result schema of this set operation.'''

    @property
    @abstractmethod
    def referenced_tables(self) -> list[Table]:
        '''Tables referenced by the SQL of this operation.'''

    def __repr__(self, pre: str = '') -> str:
        # `pre` lets callers prepend an arbitrary prefix to the class name.
        return f'{pre}{type(self).__name__}'

    @abstractmethod
    def print_tree(self, pre: str = '') -> None:
        '''Print this node; subclasses decide the layout and how `pre` is used.'''

    @property
    @abstractmethod
    def main_selects(self) -> list['Select']:
        '''Selects that are part of this set operation.'''
        return []

    @property
    @abstractmethod
    def selects(self) -> list['Select']:
        '''Every Select node found in the tree.'''
        return []
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
sqlscope/query/smt.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
'''Convert SQL expressions to Z3 expressions for logical reasoning.'''
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
from sqlglot import exp
|
|
5
|
+
from z3 import (
|
|
6
|
+
Int, IntVal,
|
|
7
|
+
Real, RealVal,
|
|
8
|
+
Bool, BoolVal,
|
|
9
|
+
String, StringVal,
|
|
10
|
+
And, Or, Not,
|
|
11
|
+
Solver,
|
|
12
|
+
unsat,
|
|
13
|
+
is_expr,
|
|
14
|
+
BoolSort,
|
|
15
|
+
ExprRef,
|
|
16
|
+
Re,
|
|
17
|
+
AllChar,
|
|
18
|
+
Concat,
|
|
19
|
+
InRe,
|
|
20
|
+
PrefixOf,
|
|
21
|
+
SuffixOf,
|
|
22
|
+
Contains,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from ..catalog import Table
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# ----------------------------------------------------------------------
|
|
29
|
+
# Z3 variable creation
|
|
30
|
+
# ----------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
def create_z3_var(variables: dict[str, Any], table_name: str | None,
                  col_name: str, col_type: Callable[[str], ExprRef] | None = None) -> None:
    '''
    Register Z3 variables for a column in `variables` (mutated in place).

    The column is stored under its bare name and, when `table_name` is
    truthy, also under `<table>.<column>`. Each entry gets a companion
    `<key>_isnull` Bool flag. `col_type` is the Z3 constructor used to build
    the value variable; Int is used when it is None.
    '''
    build = Int if col_type is None else col_type

    # Unqualified key first, then the table-qualified one (if any).
    keys = [col_name]
    if table_name:
        keys.append(f'{table_name}.{col_name}')

    for key in keys:
        variables[key] = build(key)
        null_key = f'{key}_isnull'
        variables[null_key] = Bool(null_key)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Monotonic serial used by fresh_symbol() so that every generated name is unique.
_fresh_symbol_serial = 0


def fresh_symbol(prefix: str, sort: str):
    '''
    Generate a fresh Z3 symbol with the given prefix and sort.

    `sort` selects the Z3 constructor: 'int', 'real' or 'bool'; any other
    value produces a String symbol. The symbol is named `<prefix>_<serial>`.

    The serial replaces the previous `id(prefix)` suffix: the id of a string
    literal is the same on every call within a process, so two unrelated
    subqueries were given the very same symbol and were wrongly treated as
    equal by the solver.
    '''
    global _fresh_symbol_serial
    _fresh_symbol_serial += 1
    name = f'{prefix}_{_fresh_symbol_serial}'

    if sort == 'int':
        return Int(name)
    if sort == 'real':
        return Real(name)
    if sort == 'bool':
        return Bool(name)
    return String(name)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ----------------------------------------------------------------------
|
|
63
|
+
# Infer expected type of a subquery based on parent expression
|
|
64
|
+
# ----------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
def infer_subquery_sort_from_parent(expr) -> str:
    '''
    Guess the Z3 sort name ('real', 'string' or 'bool') a subquery should
    take, looking only at the type of `expr.parent`.
    '''
    context = expr.parent

    # Arithmetic operators, ordering comparisons and BETWEEN → numeric.
    numeric_contexts = (
        exp.Add, exp.Sub, exp.Mul, exp.Div, exp.Mod, exp.Pow,
        exp.GT, exp.GTE, exp.LT, exp.LTE,
        exp.Between,
    )
    if isinstance(context, numeric_contexts):
        return 'real'

    # LIKE and string concatenation → string.
    if isinstance(context, (exp.Like, exp.Concat)):
        return 'string'

    # Everything else (EXISTS, WHERE (...)) is treated as boolean.
    return 'bool'
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ----------------------------------------------------------------------
|
|
97
|
+
# Catalog → Z3 vars
|
|
98
|
+
# ----------------------------------------------------------------------
|
|
99
|
+
|
|
100
|
+
def catalog_table_to_z3_vars(table: Table) -> dict[str, ExprRef]:
    '''
    Build the Z3 variable dictionary for every column of a catalog table.

    The upper-cased column type picks the Z3 constructor; unrecognised types
    are delegated to create_z3_var's default.
    '''
    constructor_by_type = {
        'INT': Int, 'INTEGER': Int, 'BIGINT': Int, 'SMALLINT': Int,
        'FLOAT': Real, 'REAL': Real, 'DOUBLE': Real,
        'BOOLEAN': Bool, 'BOOL': Bool,
        'VARCHAR': String, 'CHAR': String, 'TEXT': String, 'CHARACTER VARYING': String,
    }

    z3_vars = {}
    for column in table.columns:
        name = column.name
        declared_type = column.column_type.upper()
        # .get() yields None for unknown types, i.e. "use the default sort".
        create_z3_var(z3_vars, table.name, name, constructor_by_type.get(declared_type))
    return z3_vars
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# ----------------------------------------------------------------------
|
|
121
|
+
# SQL → Z3 conversion
|
|
122
|
+
# ----------------------------------------------------------------------
|
|
123
|
+
|
|
124
|
+
def sql_to_z3(expr, variables: dict[str, ExprRef] | None = None) -> Any:
    '''
    Convert a SQLGlot expression to a Z3 expression.

    `variables` maps column names to Z3 variables and is extended in place
    when an unknown column is met (new columns get create_z3_var's default
    sort). When omitted, a new dictionary is created for this call; it used
    to be a mutable default shared by every call.

    Returns None for NULL. Node types that are not handled below are
    translated to BoolVal(True). Raises NotImplementedError for a literal
    that is neither a number, a string, TRUE/FALSE nor NULL.
    '''
    if variables is None:
        variables = {}

    # --- Columns ---
    if isinstance(expr, exp.Column):
        name = expr.name.lower()
        if name not in variables:
            create_z3_var(variables, None, name)
        return variables[name]

    # --- Literals ---
    elif isinstance(expr, exp.Literal):
        val = expr.this
        if expr.is_int:
            return IntVal(int(val))
        elif expr.is_number:
            return RealVal(float(val))
        elif expr.is_string:
            return StringVal(val.strip("'"))
        elif val.upper() in ('TRUE', 'FALSE'):
            return BoolVal(val.upper() == 'TRUE')
        elif val.upper() == 'NULL':
            return None
        else:
            raise NotImplementedError(f"Unsupported literal: {val}")

    # TRUE / FALSE keywords are exp.Boolean nodes, not literals; without this
    # branch FALSE would reach the fallback and be translated as true.
    elif isinstance(expr, exp.Boolean):
        return BoolVal(bool(expr.this))

    elif isinstance(expr, exp.Null):
        return None

    # --- Boolean comparisons ---
    elif isinstance(expr, exp.EQ):
        return sql_to_z3(expr.left, variables) == sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.NEQ):
        return sql_to_z3(expr.left, variables) != sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.GT):
        return sql_to_z3(expr.left, variables) > sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.GTE):
        return sql_to_z3(expr.left, variables) >= sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.LT):
        return sql_to_z3(expr.left, variables) < sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.LTE):
        return sql_to_z3(expr.left, variables) <= sql_to_z3(expr.right, variables)

    # --- Logical connectives ---
    elif isinstance(expr, exp.And):
        return And(sql_to_z3(expr.left, variables), sql_to_z3(expr.right, variables))
    elif isinstance(expr, exp.Or):
        return Or(sql_to_z3(expr.left, variables), sql_to_z3(expr.right, variables))
    elif isinstance(expr, exp.Not):
        return Not(sql_to_z3(expr.this, variables))
    elif isinstance(expr, exp.Paren):
        return sql_to_z3(expr.this, variables)

    # --- Arithmetic ---
    elif isinstance(expr, exp.Add):
        return sql_to_z3(expr.left, variables) + sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.Sub):
        return sql_to_z3(expr.left, variables) - sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.Mul):
        return sql_to_z3(expr.left, variables) * sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.Div):
        return sql_to_z3(expr.left, variables) / sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.Mod):
        return sql_to_z3(expr.left, variables) % sql_to_z3(expr.right, variables)
    elif isinstance(expr, exp.Pow):
        return sql_to_z3(expr.left, variables) ** sql_to_z3(expr.right, variables)

    # --- BETWEEN ---
    elif isinstance(expr, exp.Between):
        target = sql_to_z3(expr.this, variables)
        low = sql_to_z3(expr.args['low'], variables)
        high = sql_to_z3(expr.args['high'], variables)
        return And(target >= low, target <= high)

    # --- IN ---
    elif isinstance(expr, exp.In):
        target = sql_to_z3(expr.this, variables)

        if isinstance(expr.args.get('query'), exp.Subquery):
            # subquery → symbolic value
            sym = fresh_symbol('subq_in', 'string')
            return target == sym

        options = [sql_to_z3(e, variables) for e in expr.expressions]

        return Or(*[target == o for o in options])

    # --- IS / IS NOT ---
    elif isinstance(expr, exp.Is):
        target_expr = expr.this
        right_expr = expr.args.get('expression')

        if isinstance(right_expr, exp.Null):
            if isinstance(target_expr, exp.Column):
                # Nullness of a column is tracked by its `<name>_isnull` flag.
                name = target_expr.name.lower()
                flag = variables.setdefault(f'{name}_isnull', Bool(f'{name}_isnull'))
                return flag
            return BoolVal(False)

        if isinstance(right_expr, exp.Not) and isinstance(right_expr.this, exp.Null):
            if isinstance(target_expr, exp.Column):
                name = target_expr.name.lower()
                flag = variables.setdefault(f'{name}_isnull', Bool(f'{name}_isnull'))
                return Not(flag)
            return BoolVal(True)

        return sql_to_z3(target_expr, variables) == sql_to_z3(right_expr, variables)

    # --- LIKE ---
    elif isinstance(expr, exp.Like):
        target = sql_to_z3(expr.this, variables)
        pattern_expr = sql_to_z3(expr.expression, variables)

        # If pattern is a variable → fallback
        if not isinstance(expr.expression, exp.Literal):
            return target == pattern_expr

        pattern = expr.expression.this.strip("'")
        wildcard_count = pattern.count('%') + pattern.count('_')

        # More than two wildcards: approximated as plain equality with the pattern text.
        if wildcard_count > 2:
            return target == StringVal(pattern)

        if '%' in pattern and '_' not in pattern:
            # PREFIX pattern: abc%
            if pattern.endswith('%') and pattern.count('%') == 1:
                prefix = pattern[:-1]
                return PrefixOf(StringVal(prefix), target)

            # CONTAINS: %abc%
            if pattern.startswith('%') and pattern.endswith('%') and pattern.count('%') == 2:
                mid = pattern[1:-1]
                return Contains(target, StringVal(mid))

            # SUFFIX: %abc
            if pattern.startswith('%') and pattern.count('%') == 1:
                suffix = pattern[1:]
                return SuffixOf(StringVal(suffix), target)

        # EXACTLY ONE '_' wildcard
        if '_' in pattern and '%' not in pattern and wildcard_count == 1:
            # Regex = text before '_' + any single character + text after '_'.
            parts = pattern.split('_')
            regex = None
            for i, p in enumerate(parts):
                r = Re(StringVal(p))
                regex = r if regex is None else Concat(regex, r)
                if i < len(parts) - 1:
                    regex = Concat(regex, AllChar(r.sort()))
            return InRe(target, regex)

        return target == StringVal(pattern)

    # --- EXISTS ---
    elif isinstance(expr, exp.Exists):
        return fresh_symbol('subq_exists', 'bool')

    # --- SUBQUERY ---
    elif isinstance(expr, exp.Subquery):
        sort = infer_subquery_sort_from_parent(expr)
        if sort == 'int':
            return fresh_symbol('subq_val', 'int')
        elif sort == 'real':
            return fresh_symbol('subq_val', 'real')
        elif sort == 'string':
            return fresh_symbol('subq_val', 'string')
        else:
            return fresh_symbol('subq_bool', 'bool')

    # --- Fallback ---
    return BoolVal(True)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# ----------------------------------------------------------------------
|
|
297
|
+
# Formula checking
|
|
298
|
+
# ----------------------------------------------------------------------
|
|
299
|
+
|
|
300
|
+
def check_formula(expr) -> str:
    '''
    Classify a SQLGlot expression as 'tautology', 'contradiction' or
    'contingent'; returns 'unknown' when it translates to None.
    '''
    z3_formula = sql_to_z3(expr, {})
    if z3_formula is None:
        return 'unknown'

    solver = Solver()

    # The formula itself cannot be satisfied → it is never true.
    solver.push()
    solver.add(z3_formula)
    never_true = solver.check() == unsat
    if never_true:
        return 'contradiction'
    solver.pop()

    # Its negation cannot be satisfied → it is always true.
    solver.push()
    solver.add(Not(z3_formula))
    return 'tautology' if solver.check() == unsat else 'contingent'
|
|
324
|
+
|
|
325
|
+
def is_satisfiable(expr_z3) -> bool:
    '''Return True unless the solver reports the Z3 expression as unsat.'''
    checker = Solver()
    checker.add(expr_z3)
    return checker.check() != unsat
|
|
332
|
+
|
|
333
|
+
def is_bool_expr(e) -> bool:
    '''Return True when `e` is a Z3 expression whose sort kind matches BoolSort.'''
    if not is_expr(e):
        return False
    return e.sort().kind() == BoolSort().kind()
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
'''Query representation as a list of tokens. Works even for invalid SQL.'''
|
|
2
|
+
|
|
3
|
+
import sqlparse
|
|
4
|
+
import sqlparse.tokens
|
|
5
|
+
from sqlparse.tokens import Whitespace, Newline
|
|
6
|
+
|
|
7
|
+
from . import extractors
|
|
8
|
+
|
|
9
|
+
class TokenizedSQL:
    '''
    Base class for tokenizing SQL queries.

    The query is parsed with sqlparse on construction; the token, function
    and comparison lists are computed on first access and cached.
    '''

    def __init__(self, sql: str) -> None:
        self.sql = sql
        '''The full SQL query string.'''

        # Lazy properties (None means "not computed yet")
        self._tokens = None
        self._functions = None
        self._comparisons = None
        # End of lazy properties

        parsed_statements = sqlparse.parse(self.sql)
        if not parsed_statements:
            # Nothing was parsed: fall back to an empty statement so that
            # `self.parsed` is always a Statement.
            self.all_statements: list[sqlparse.sql.Statement] = []
            self.parsed = sqlparse.sql.Statement()
        else:
            self.all_statements = list(parsed_statements)
            # Only the first statement is analysed by the properties below.
            self.parsed = parsed_statements[0]

    # region Properties
    @property
    def tokens(self) -> list[tuple[sqlparse.tokens._TokenType, str]]:
        '''Returns a flattened list of tokens as (ttype, value) tuples, excluding whitespace and newlines.'''
        # Test against None (like the other lazy properties) so that an empty
        # token list is cached instead of being re-flattened on every access.
        if self._tokens is None:
            self._tokens = self._flatten()
        return self._tokens

    @property
    def functions(self) -> list[tuple[sqlparse.sql.Function, str]]:
        '''Returns a list of (function, clause) tuples found in the SQL query.'''

        if self._functions is None:
            self._functions = extractors.extract_functions(self.parsed.tokens)
        return self._functions

    @property
    def comparisons(self) -> list[tuple[sqlparse.sql.Comparison, str]]:
        '''Returns a list of (comparison, clause) tuples found in the SQL query.'''
        if self._comparisons is None:
            self._comparisons = extractors.extract_comparisons(self.parsed.tokens)
        return self._comparisons

    # endregion

    def _flatten(self) -> list[tuple[sqlparse.tokens._TokenType, str]]:
        '''Flattens the parsed SQL statement into a list of (ttype, value) tuples. Ignores whitespace and newlines.'''

        if not self.parsed:
            return []

        # Flatten tokens into (ttype, value)
        return [
            (tok.ttype, tok.value) for tok in self.parsed.flatten()
            if tok.ttype not in (Whitespace, Newline)
        ]

    def print_tree(self) -> None:
        '''Pretty-print the sqlparse tree of every parsed statement.'''
        for stmt in self.all_statements:
            stmt._pprint_tree()
|
|
70
|
+
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .base import get_type
|
|
2
|
+
from . import primitives, functions, queries, unary_ops, binary_ops, predicates
|
|
3
|
+
from ...catalog import Catalog
|
|
4
|
+
from sqlglot import exp
|
|
5
|
+
from sqlglot.optimizer.annotate_types import annotate_types
|
|
6
|
+
from sqlglot.optimizer.qualify import qualify
|
|
7
|
+
|
|
8
|
+
__all__ = ["get_type"]
|
|
9
|
+
|
|
10
|
+
def rewrite_expression(expression: exp.Expression, catalog: Catalog, search_path: str = 'public') -> exp.Expression:
    '''
    Qualify the expression against the catalog schema (using `search_path`
    as the database) and return it with types annotated on its nodes.
    '''
    sqlglot_schema = catalog.to_sqlglot_schema()
    qualified = qualify(
        expression,
        schema=sqlglot_schema,
        db=search_path,
        validate_qualify_columns=False,
    )
    return annotate_types(qualified, sqlglot_schema)
|
|
18
|
+
|
|
19
|
+
# This function needs to be called on a typed expression
|
|
20
|
+
def collect_errors(expression: exp.Expression, catalog: Catalog, search_path: str = 'public') -> list[tuple[str, str, str | None]]:
    '''
    Return the messages gathered while type-checking `expression`.

    The expression must already carry type annotations.
    '''
    checked = get_type(expression, catalog, search_path)
    return checked.messages
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
import sqlglot.expressions as exp
|
|
2
|
+
from .types import AtomicType, ResultType
|
|
3
|
+
from ...catalog import Catalog
|
|
4
|
+
from functools import singledispatch
|
|
5
|
+
|
|
6
|
+
@singledispatch
def get_type(expression: exp.Expression, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Return the type of the given SQL expression.

    This is the single-dispatch fallback: expression classes without a
    registered handler get a default-constructed AtomicType.
    '''
    unhandled = AtomicType()
    return unhandled
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from .base import get_type
|
|
2
|
+
from ...catalog import Catalog
|
|
3
|
+
from sqlglot import exp
|
|
4
|
+
from .types import ResultType, AtomicType
|
|
5
|
+
from sqlglot.expressions import DataType
|
|
6
|
+
from .util import is_number, to_number, to_date, error_message
|
|
7
|
+
|
|
8
|
+
@get_type.register
def _(expression: exp.Binary, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Type-check a binary expression.

    Predicates are delegated to typecheck_comparisons; every other binary
    operator expects numeric operands and reports the offending ones.
    '''
    left_type = get_type(expression.this, catalog, search_path)
    right_type = get_type(expression.expression, catalog, search_path)

    # New list: the operands' own message lists are left untouched.
    messages = left_type.messages + right_type.messages

    # handle comparison operators
    if isinstance(expression, exp.Predicate):
        return typecheck_comparisons(left_type, right_type, expression, messages)

    unknown = DataType.Type.UNKNOWN
    null = DataType.Type.NULL

    if left_type != right_type:
        # Operand types differ: each side must be convertible to a number
        # unless it is UNKNOWN or NULL.
        for operand in (left_type, right_type):
            if operand.data_type != unknown and not to_number(operand) and operand.data_type != null:
                messages.append(error_message(expression, operand, "numeric"))

    elif (unknown != left_type.data_type
          and not is_number(left_type.data_type)
          and not is_number(right_type.data_type)):
        # Same non-numeric type on both sides: an error unless both are NULL.
        both_null = left_type.data_type == null and right_type.data_type == null
        if not both_null:
            messages.append(error_message(expression, left_type, "numeric"))

    return AtomicType(
        data_type=expression.type.this,
        nullable=left_type.nullable or right_type.nullable,
        constant=left_type.constant and right_type.constant,
        messages=messages,
    )
|
|
33
|
+
|
|
34
|
+
# handle comparison typechecking (e.g =, <, >, etc.)
|
|
35
|
+
def typecheck_comparisons(left_type: ResultType, right_type: ResultType, expression: exp.Binary, old_messages: list) -> ResultType:
    '''
    Type-check a comparison (e.g. =, <, >) between two already-typed operands.

    Errors are appended to `old_messages` (mutated in place), which is then
    attached to the returned type.
    '''
    if DataType.Type.UNKNOWN in (left_type.data_type, right_type.data_type):
        return AtomicType(data_type=expression.type.this, messages=old_messages)

    # Two booleans may only be compared for equality / inequality.
    both_boolean = DataType.Type.BOOLEAN == left_type.data_type == right_type.data_type
    if both_boolean and not isinstance(expression, (exp.EQ, exp.NEQ)):
        old_messages.append(error_message(expression, left_type, "boolean"))

    if left_type != right_type and left_type.data_type != DataType.Type.NULL and right_type.data_type != DataType.Type.NULL:
        # Differing types are accepted only when an implicit cast exists:
        # both sides numeric, or both sides dates.
        castable = (
            (to_number(left_type) and to_number(right_type))
            or (to_date(left_type) and to_date(right_type))
        )
        if not castable:
            old_messages.append(error_message(expression, left_type.data_type_str + " & " + right_type.data_type_str))

    # Always returns boolean
    return AtomicType(data_type=expression.type.this, nullable=False, constant=True, messages=old_messages)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from .base import get_type
|
|
2
|
+
from ...catalog import Catalog
|
|
3
|
+
from sqlglot import exp
|
|
4
|
+
from .types import ResultType, AtomicType
|
|
5
|
+
from sqlglot.expressions import DataType
|
|
6
|
+
from .util import is_number, error_message
|
|
7
|
+
|
|
8
|
+
@get_type.register
def _(expression: exp.Count, catalog: Catalog, search_path: str) -> ResultType:
    '''Type of COUNT: non-nullable, carrying the messages of its argument.'''
    argument_type = get_type(expression.this, catalog, search_path)

    return AtomicType(
        data_type=expression.type.this,
        nullable=False,
        constant=True,
        messages=argument_type.messages,
    )
|
|
13
|
+
|
|
14
|
+
@get_type.register
def _(expression: exp.Avg, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Type of AVG: nullable; reports an error when the argument type is known
    and not numeric.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy so the error below is not appended to the argument's own list
    # (the Binary and Concat handlers also build a new list).
    messages = list(inner_type.messages)

    if inner_type.data_type != DataType.Type.UNKNOWN and not is_number(inner_type.data_type):
        messages.append(error_message(expression, inner_type, "NUMERIC"))

    return AtomicType(data_type=expression.type.this, nullable=True, constant=True, messages=messages)
|
|
24
|
+
|
|
25
|
+
@get_type.register
def _(expression: exp.Sum, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Type of SUM: nullable; reports an error when the argument type is known
    and not numeric.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy so the error below is not appended to the argument's own list
    # (the Binary and Concat handlers also build a new list).
    messages = list(inner_type.messages)

    if inner_type.data_type != DataType.Type.UNKNOWN and not is_number(inner_type.data_type):
        messages.append(error_message(expression, inner_type, "NUMERIC"))

    return AtomicType(data_type=expression.type.this, nullable=True, constant=True, messages=messages)
|
|
35
|
+
|
|
36
|
+
@get_type.register
def _(expression: exp.Min, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Type of MIN: same data type and nullability as its argument; a BOOLEAN
    argument is reported as an error.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy so the error below is not appended to the argument's own list
    # (the Binary and Concat handlers also build a new list).
    messages = list(inner_type.messages)

    # The former `!= UNKNOWN` guard was implied by the equality test.
    if inner_type.data_type == DataType.Type.BOOLEAN:
        messages.append(error_message(expression, inner_type))

    return AtomicType(data_type=inner_type.data_type, nullable=inner_type.nullable, constant=True, messages=messages)
|
|
46
|
+
|
|
47
|
+
@get_type.register
def _(expression: exp.Max, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Type of MAX: same data type and nullability as its argument; a BOOLEAN
    argument is reported as an error.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy so the error below is not appended to the argument's own list
    # (the Binary and Concat handlers also build a new list).
    messages = list(inner_type.messages)

    # The former `!= UNKNOWN` guard was implied by the equality test.
    if inner_type.data_type == DataType.Type.BOOLEAN:
        messages.append(error_message(expression, inner_type))

    return AtomicType(data_type=inner_type.data_type, nullable=inner_type.nullable, constant=True, messages=messages)
|
|
57
|
+
|
|
58
|
+
@get_type.register
|
|
59
|
+
def _(expression: exp.Concat, catalog: Catalog, search_path: str) -> ResultType:
|
|
60
|
+
old_messages = []
|
|
61
|
+
args_type = []
|
|
62
|
+
|
|
63
|
+
for arg in expression.expressions:
|
|
64
|
+
arg_type = get_type(arg, catalog, search_path)
|
|
65
|
+
if arg_type.messages:
|
|
66
|
+
old_messages.extend(arg_type.messages)
|
|
67
|
+
args_type.append(arg_type)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
if not args_type:
|
|
71
|
+
old_messages.append(error_message(expression, "Empty arguments"))
|
|
72
|
+
|
|
73
|
+
# if all args are NULL, result is NULL
|
|
74
|
+
if all(target_type.data_type == DataType.Type.NULL for target_type in args_type):
|
|
75
|
+
return AtomicType(data_type=DataType.Type.NULL, constant=True, messages=old_messages)
|
|
76
|
+
|
|
77
|
+
constant = all(target_type.constant for target_type in args_type)
|
|
78
|
+
nullable = any(target_type.nullable for target_type in args_type)
|
|
79
|
+
|
|
80
|
+
# Always returns VARCHAR
|
|
81
|
+
return AtomicType(data_type=expression.type.this, nullable=nullable, constant=constant, messages=old_messages)
|