sqlscope 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (41)
  1. sqlscope/__init__.py +4 -0
  2. sqlscope/catalog/__init__.py +12 -0
  3. sqlscope/catalog/builder/__init__.py +8 -0
  4. sqlscope/catalog/builder/postgres.py +207 -0
  5. sqlscope/catalog/builder/sql.py +219 -0
  6. sqlscope/catalog/catalog.py +147 -0
  7. sqlscope/catalog/column.py +68 -0
  8. sqlscope/catalog/constraint.py +83 -0
  9. sqlscope/catalog/schema.py +60 -0
  10. sqlscope/catalog/table.py +112 -0
  11. sqlscope/query/__init__.py +5 -0
  12. sqlscope/query/extractors.py +118 -0
  13. sqlscope/query/query.py +191 -0
  14. sqlscope/query/set_operations/__init__.py +181 -0
  15. sqlscope/query/set_operations/binary_set_operation.py +162 -0
  16. sqlscope/query/set_operations/select.py +664 -0
  17. sqlscope/query/set_operations/set_operation.py +59 -0
  18. sqlscope/query/smt.py +334 -0
  19. sqlscope/query/tokenized_sql.py +70 -0
  20. sqlscope/query/typechecking/__init__.py +21 -0
  21. sqlscope/query/typechecking/base.py +9 -0
  22. sqlscope/query/typechecking/binary_ops.py +57 -0
  23. sqlscope/query/typechecking/functions.py +81 -0
  24. sqlscope/query/typechecking/predicates.py +123 -0
  25. sqlscope/query/typechecking/primitives.py +80 -0
  26. sqlscope/query/typechecking/queries.py +35 -0
  27. sqlscope/query/typechecking/types.py +59 -0
  28. sqlscope/query/typechecking/unary_ops.py +51 -0
  29. sqlscope/query/typechecking/util.py +51 -0
  30. sqlscope/util/__init__.py +18 -0
  31. sqlscope/util/ast/__init__.py +55 -0
  32. sqlscope/util/ast/column.py +55 -0
  33. sqlscope/util/ast/function.py +10 -0
  34. sqlscope/util/ast/subquery.py +23 -0
  35. sqlscope/util/ast/table.py +36 -0
  36. sqlscope/util/sql.py +27 -0
  37. sqlscope/util/tokens.py +17 -0
  38. sqlscope-1.0.0.dist-info/METADATA +52 -0
  39. sqlscope-1.0.0.dist-info/RECORD +41 -0
  40. sqlscope-1.0.0.dist-info/WHEEL +4 -0
  41. sqlscope-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,59 @@
1
+ import sqlglot
2
+ from sqlglot import exp
3
+ from ...catalog import Table, Column
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from typing import TYPE_CHECKING
8
+ if TYPE_CHECKING:
9
+ from .select import Select
10
+
11
+
12
class SetOperation(ABC):
    '''
    Abstract base for every node of a parsed query tree: a plain SELECT or a
    set operation combining queries (UNION, INTERSECT, EXCEPT).

    Concrete subclasses implement `output`, `referenced_tables`,
    `print_tree`, `main_selects` and `selects`.
    '''

    def __init__(self, sql: str, parent_query: 'Select | None' = None) -> None:
        # SQL text represented by this node.
        self.sql = sql
        # Enclosing Select when this node is a subquery; None otherwise.
        self.parent_query = parent_query

    @property
    @abstractmethod
    def output(self) -> Table:
        '''Schema of the table produced by this node.'''

    @property
    @abstractmethod
    def referenced_tables(self) -> list[Table]:
        '''Tables referenced by this node's SQL.'''

    def __repr__(self, pre: str = '') -> str:
        # `pre` is placed verbatim in front of the class name.
        return f'{pre}{type(self).__name__}'

    @abstractmethod
    def print_tree(self, pre: str = '') -> None:
        '''Print the tree rooted at this node; `pre` is a caller-supplied prefix (presumably indentation).'''

    @property
    @abstractmethod
    def main_selects(self) -> list['Select']:
        '''Selects that directly make up this set operation (base implementation returns []).'''
        return []

    @property
    @abstractmethod
    def selects(self) -> list['Select']:
        '''Every Select node in the tree rooted here (base implementation returns []).'''
        return []
55
+
56
+
57
+
58
+
59
+
sqlscope/query/smt.py ADDED
@@ -0,0 +1,334 @@
1
+ '''Convert SQL expressions to Z3 expressions for logical reasoning.'''
2
+
3
+ from typing import Any, Callable
4
+ from sqlglot import exp
5
+ from z3 import (
6
+ Int, IntVal,
7
+ Real, RealVal,
8
+ Bool, BoolVal,
9
+ String, StringVal,
10
+ And, Or, Not,
11
+ Solver,
12
+ unsat,
13
+ is_expr,
14
+ BoolSort,
15
+ ExprRef,
16
+ Re,
17
+ AllChar,
18
+ Concat,
19
+ InRe,
20
+ PrefixOf,
21
+ SuffixOf,
22
+ Contains,
23
+ )
24
+
25
+ from ..catalog import Table
26
+
27
+
28
+ # ----------------------------------------------------------------------
29
+ # Z3 variable creation
30
+ # ----------------------------------------------------------------------
31
+
32
def create_z3_var(variables: dict[str, Any], table_name: str | None,
                  col_name: str, col_type: Callable[[str], ExprRef] | None = None) -> None:
    '''
    Register Z3 symbols for one column in `variables` (mutated in place).

    For every registered name two entries are written: the value symbol,
    built by calling `col_type` with the name (Int when `col_type` is None),
    and a Bool `<name>_isnull` flag. The bare column name is always
    registered; if `table_name` is truthy, `<table>.<column>` is registered
    too. Existing entries are overwritten.
    '''
    make_symbol = Int if col_type is None else col_type

    names = [col_name]
    if table_name:
        names.append(f'{table_name}.{col_name}')

    for key in names:
        variables[key] = make_symbol(key)
        variables[f'{key}_isnull'] = Bool(f'{key}_isnull')
49
+
50
+
51
def fresh_symbol(prefix: str, sort: str):
    '''
    Return a new Z3 constant of the requested sort that is distinct from every other constant.

    `sort` is 'int', 'real' or 'bool'; any other value yields a String
    constant. Z3 derives a unique name from `prefix`, so repeated calls with
    the same prefix never return the same symbol. Suffixing with id(prefix)
    would not give this guarantee, because a repeated string literal has the
    same id on every call.
    '''
    from z3 import FreshBool, FreshConst, FreshInt, FreshReal, StringSort

    if sort == 'int':
        return FreshInt(prefix)
    if sort == 'real':
        return FreshReal(prefix)
    if sort == 'bool':
        return FreshBool(prefix)
    return FreshConst(StringSort(), prefix)
60
+
61
+
62
+ # ----------------------------------------------------------------------
63
+ # Infer expected type of a subquery based on parent expression
64
+ # ----------------------------------------------------------------------
65
+
66
def infer_subquery_sort_from_parent(expr) -> str:
    '''
    Infer the expected Z3 sort of a subquery from the node that contains it.

    Returns
    -------
    'real'   when the parent is arithmetic (including unary minus), an
             ordering comparison (<, <=, >, >=) or BETWEEN;
    'string' when the parent is LIKE / ILIKE or a string concatenation,
             either CONCAT(...) (exp.Concat) or the `||` operator (exp.DPipe);
    'bool'   otherwise (e.g. EXISTS, or a subquery used directly as a predicate).
    '''
    parent = expr.parent

    # Numeric contexts.
    if isinstance(parent, (exp.Add, exp.Sub, exp.Mul, exp.Div, exp.Mod, exp.Pow, exp.Neg,
                           exp.GT, exp.GTE, exp.LT, exp.LTE, exp.Between)):
        return 'real'

    # String contexts. sqlglot parses `a || b` as exp.DPipe, not exp.Concat.
    if isinstance(parent, (exp.Like, exp.ILike, exp.Concat, exp.DPipe)):
        return 'string'

    # Default: boolean (EXISTS, WHERE (...)).
    return 'bool'
94
+
95
+
96
+ # ----------------------------------------------------------------------
97
+ # Catalog → Z3 vars
98
+ # ----------------------------------------------------------------------
99
+
100
# Z3 constructor for each recognised SQL type name (upper-cased, without
# length/precision modifiers). Unlisted types fall back to Int.
_Z3_CONSTRUCTOR_BY_SQL_TYPE: dict[str, Callable[[str], ExprRef]] = {
    **dict.fromkeys(('INT', 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT',
                     'INT2', 'INT4', 'INT8', 'SERIAL', 'BIGSERIAL', 'SMALLSERIAL'), Int),
    **dict.fromkeys(('FLOAT', 'REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT4', 'FLOAT8',
                     'NUMERIC', 'DECIMAL'), Real),
    **dict.fromkeys(('BOOLEAN', 'BOOL'), Bool),
    **dict.fromkeys(('VARCHAR', 'CHAR', 'TEXT', 'CHARACTER VARYING', 'CHARACTER',
                     'BPCHAR', 'STRING'), String),
}


def catalog_table_to_z3_vars(table: Table) -> dict[str, ExprRef]:
    '''
    Convert catalog table columns to Z3 variables.

    Each column is registered through create_z3_var, both unqualified and as
    `<table>.<column>`, together with its `_isnull` flags. The Z3 sort comes
    from the column's declared type. Modifiers such as VARCHAR(255) or
    NUMERIC(10, 2) are ignored, and unrecognised or missing types default to Int.
    '''
    variables: dict[str, ExprRef] = {}
    for column in table.columns:
        type_name = (column.column_type or '').upper()
        # Strip length/precision modifiers so 'VARCHAR(255)' maps like 'VARCHAR'.
        base_type = type_name.split('(', 1)[0].strip()
        # None lets create_z3_var apply its Int default.
        constructor = _Z3_CONSTRUCTOR_BY_SQL_TYPE.get(base_type)
        create_z3_var(variables, table.name, column.name, constructor)
    return variables
118
+
119
+
120
+ # ----------------------------------------------------------------------
121
+ # SQL → Z3 conversion
122
+ # ----------------------------------------------------------------------
123
+
124
# Comparison node types paired with the Python operator that builds the Z3 term.
_COMPARISONS = (
    (exp.EQ, lambda a, b: a == b),
    (exp.NEQ, lambda a, b: a != b),
    (exp.GT, lambda a, b: a > b),
    (exp.GTE, lambda a, b: a >= b),
    (exp.LT, lambda a, b: a < b),
    (exp.LTE, lambda a, b: a <= b),
)


def _is_sort_flexible(node, variables) -> bool:
    '''
    Return True when `node`'s Z3 sort is not fixed by the node itself.

    This holds for scalar subqueries and for columns that have no symbol in
    `variables` yet. Such operands borrow the sort of the other operand.
    '''
    if isinstance(node, exp.Subquery):
        return True
    return isinstance(node, exp.Column) and node.name.lower() not in variables


def _convert_with_sort(node, reference, variables):
    '''Convert `node`, giving it the sort of the Z3 term `reference` if `node` is sort-flexible.'''
    if not is_expr(reference) or not _is_sort_flexible(node, variables):
        return sql_to_z3(node, variables)

    from z3 import Const, FreshConst

    sort = reference.sort()
    if isinstance(node, exp.Subquery):
        # Opaque scalar subquery: an unconstrained value of the matching sort.
        return FreshConst(sort, 'subq_val')

    name = node.name.lower()
    create_z3_var(variables, None, name, lambda n: Const(n, sort))
    return variables[name]


def _convert_operands(left_node, right_node, variables):
    '''Convert both operands of a binary predicate so that their Z3 sorts agree.'''
    left_flexible = _is_sort_flexible(left_node, variables)
    right_flexible = _is_sort_flexible(right_node, variables)

    if right_flexible and not isinstance(left_node, exp.Subquery):
        # The left side fixes the sort; an unseen left column gets the default Int.
        left = sql_to_z3(left_node, variables)
        return left, _convert_with_sort(right_node, left, variables)
    if left_flexible:
        right = sql_to_z3(right_node, variables)
        return _convert_with_sort(left_node, right, variables), right
    return sql_to_z3(left_node, variables), sql_to_z3(right_node, variables)


def sql_to_z3(expr, variables: dict[str, ExprRef] | None = None) -> Any:
    '''
    Convert a SQLGlot expression to a Z3 expression.

    Parameters
    ----------
    expr:
        SQLGlot node to translate.
    variables:
        Symbol table mapping column names (and `<name>_isnull` flags) to Z3
        constants. It is updated in place when unseen columns are met, so it
        can be pre-populated (e.g. with catalog_table_to_z3_vars) and
        inspected afterwards. A new empty dict is used when omitted.

    Returns
    -------
    A Z3 expression, or None for SQL NULL.

    Unseen columns default to Int, except where the other operand of a
    comparison, IN or LIKE fixes their sort. Subqueries become fresh
    unconstrained symbols. Unsupported nodes are approximated as BoolVal(True).
    '''
    if variables is None:
        # A shared `{}` default would leak columns (and their sorts) between calls.
        variables = {}

    # --- Columns ---
    if isinstance(expr, exp.Column):
        name = expr.name.lower()
        if name not in variables:
            create_z3_var(variables, None, name)
        return variables[name]

    # --- Literals ---
    if isinstance(expr, exp.Literal):
        val = expr.this
        if expr.is_int:
            return IntVal(int(val))
        if expr.is_number:
            return RealVal(float(val))
        if expr.is_string:
            return StringVal(val.strip("'"))
        if val.upper() in ('TRUE', 'FALSE'):
            return BoolVal(val.upper() == 'TRUE')
        if val.upper() == 'NULL':
            return None
        raise NotImplementedError(f"Unsupported literal: {val}")

    if isinstance(expr, exp.Null):
        return None

    # TRUE / FALSE are parsed as exp.Boolean, not exp.Literal.
    if isinstance(expr, exp.Boolean):
        return BoolVal(bool(expr.this))

    # --- Boolean comparisons ---
    for node_type, compare in _COMPARISONS:
        if isinstance(expr, node_type):
            left, right = _convert_operands(expr.left, expr.right, variables)
            return compare(left, right)

    # --- Logical connectives ---
    if isinstance(expr, exp.And):
        return And(sql_to_z3(expr.left, variables), sql_to_z3(expr.right, variables))
    if isinstance(expr, exp.Or):
        return Or(sql_to_z3(expr.left, variables), sql_to_z3(expr.right, variables))
    if isinstance(expr, exp.Not):
        return Not(sql_to_z3(expr.this, variables))
    if isinstance(expr, exp.Paren):
        return sql_to_z3(expr.this, variables)

    # --- Arithmetic ---
    if isinstance(expr, exp.Add):
        return sql_to_z3(expr.left, variables) + sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Sub):
        return sql_to_z3(expr.left, variables) - sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Mul):
        return sql_to_z3(expr.left, variables) * sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Div):
        return sql_to_z3(expr.left, variables) / sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Mod):
        return sql_to_z3(expr.left, variables) % sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Pow):
        return sql_to_z3(expr.left, variables) ** sql_to_z3(expr.right, variables)
    if isinstance(expr, exp.Neg):
        return -sql_to_z3(expr.this, variables)

    # --- BETWEEN ---
    if isinstance(expr, exp.Between):
        target = sql_to_z3(expr.this, variables)
        low = sql_to_z3(expr.args['low'], variables)
        high = sql_to_z3(expr.args['high'], variables)
        return And(target >= low, target <= high)

    # --- IN ---
    if isinstance(expr, exp.In):
        if isinstance(expr.args.get('query'), exp.Subquery):
            target = sql_to_z3(expr.this, variables)
            if is_expr(target):
                from z3 import FreshConst
                # The subquery's values are opaque: compare against an
                # unconstrained value of the target's own sort.
                return target == FreshConst(target.sort(), 'subq_in')
            # NULL target: the result is unknown, so use an unconstrained boolean.
            return fresh_symbol('subq_in', 'bool')

        options = [sql_to_z3(option, variables) for option in expr.expressions]
        # Let an unseen target column take the sort of the listed values.
        reference = next((option for option in options if is_expr(option)), None)
        target = _convert_with_sort(expr.this, reference, variables)
        return Or(*[target == option for option in options])

    # --- IS / IS NOT ---
    if isinstance(expr, exp.Is):
        target_expr = expr.this
        right_expr = expr.args.get('expression')

        if isinstance(right_expr, exp.Null):
            if isinstance(target_expr, exp.Column):
                name = target_expr.name.lower()
                return variables.setdefault(f'{name}_isnull', Bool(f'{name}_isnull'))
            return BoolVal(False)

        if isinstance(right_expr, exp.Not) and isinstance(right_expr.this, exp.Null):
            if isinstance(target_expr, exp.Column):
                name = target_expr.name.lower()
                return Not(variables.setdefault(f'{name}_isnull', Bool(f'{name}_isnull')))
            return BoolVal(True)

        left, right = _convert_operands(target_expr, right_expr, variables)
        return left == right

    # --- LIKE ---
    if isinstance(expr, exp.Like):
        # LIKE works on strings. An unseen target column becomes a String
        # symbol, because the Int default would break the string operations below.
        target = _convert_with_sort(expr.this, StringVal(''), variables)
        pattern_expr = _convert_with_sort(expr.expression, target, variables)

        # If pattern is a variable → fallback
        if not isinstance(expr.expression, exp.Literal):
            return target == pattern_expr

        pattern = expr.expression.this.strip("'")
        wildcard_count = pattern.count('%') + pattern.count('_')

        if wildcard_count > 2:
            return target == StringVal(pattern)

        if '%' in pattern and '_' not in pattern:
            # PREFIX pattern: abc%
            if pattern.endswith('%') and pattern.count('%') == 1:
                return PrefixOf(StringVal(pattern[:-1]), target)
            # CONTAINS: %abc%
            if pattern.startswith('%') and pattern.endswith('%') and pattern.count('%') == 2:
                return Contains(target, StringVal(pattern[1:-1]))
            # SUFFIX: %abc
            if pattern.startswith('%') and pattern.count('%') == 1:
                return SuffixOf(StringVal(pattern[1:]), target)

        # EXACTLY ONE '_' wildcard: literal parts joined by a single any-char.
        if '_' in pattern and '%' not in pattern and wildcard_count == 1:
            parts = pattern.split('_')
            regex = None
            for i, part in enumerate(parts):
                piece = Re(StringVal(part))
                regex = piece if regex is None else Concat(regex, piece)
                if i < len(parts) - 1:
                    regex = Concat(regex, AllChar(piece.sort()))
            return InRe(target, regex)

        return target == StringVal(pattern)

    # --- EXISTS ---
    if isinstance(expr, exp.Exists):
        return fresh_symbol('subq_exists', 'bool')

    # --- SUBQUERY ---
    if isinstance(expr, exp.Subquery):
        sort = infer_subquery_sort_from_parent(expr)
        if sort in ('int', 'real', 'string'):
            return fresh_symbol('subq_val', sort)
        return fresh_symbol('subq_bool', 'bool')

    # --- Fallback ---
    return BoolVal(True)
294
+
295
+
296
+ # ----------------------------------------------------------------------
297
+ # Formula checking
298
+ # ----------------------------------------------------------------------
299
+
300
def check_formula(expr) -> str:
    '''
    Check if the given SQLGlot expression is a tautology, contradiction, or contingent.

    Returns
    -------
    'contradiction' if Z3 proves the formula unsatisfiable, 'tautology' if
    Z3 proves its negation unsatisfiable, 'contingent' if both checks are
    satisfiable, and 'unknown' in two cases: the expression does not
    translate to a Z3 boolean (e.g. it is NULL), or Z3 returns `unknown`
    for a check that would decide the answer.
    '''
    from z3 import unknown as z3_unknown

    formula = sql_to_z3(expr, {})

    # Solver.add only accepts boolean terms; anything else cannot be classified.
    if formula is None or not is_bool_expr(formula):
        return 'unknown'

    solver = Solver()

    solver.push()
    solver.add(formula)
    satisfiable = solver.check()
    solver.pop()
    if satisfiable == unsat:
        return 'contradiction'

    solver.push()
    solver.add(Not(formula))
    falsifiable = solver.check()
    solver.pop()
    if falsifiable == unsat:
        return 'tautology'

    # An undecided check must not be reported as 'contingent'.
    if z3_unknown in (satisfiable, falsifiable):
        return 'unknown'
    return 'contingent'
324
+
325
def is_satisfiable(expr_z3) -> bool:
    '''
    Return False only when Z3 proves `expr_z3` unsatisfiable.

    Both `sat` and `unknown` solver results count as satisfiable, so the
    answer errs on the side of "may be satisfiable".
    '''
    checker = Solver()
    checker.add(expr_z3)
    return checker.check() != unsat
332
+
333
def is_bool_expr(e) -> bool:
    '''Return True when `e` is a Z3 expression whose sort is Boolean.'''
    if not is_expr(e):
        return False
    return e.sort().kind() == BoolSort().kind()
@@ -0,0 +1,70 @@
1
+ '''Query representation as a list of tokens. Works even for invalid SQL.'''
2
+
3
+ import sqlparse
4
+ import sqlparse.tokens
5
+ from sqlparse.tokens import Whitespace, Newline
6
+
7
+ from . import extractors
8
+
9
class TokenizedSQL:
    '''
    Token-level representation of a SQL string, built with sqlparse.

    sqlparse does not validate its input, so this also works for invalid SQL.
    The `tokens`, `functions` and `comparisons` views are computed lazily on
    first access and then cached.
    '''

    def __init__(self, sql: str) -> None:
        # The full SQL query string.
        self.sql = sql

        # Lazy caches: None means "not computed yet".
        self._tokens = None
        self._functions = None
        self._comparisons = None

        parsed_statements = sqlparse.parse(self.sql)
        if not parsed_statements:
            self.all_statements: list[sqlparse.sql.Statement] = []
            # Empty placeholder so that `parsed` is always a Statement.
            self.parsed = sqlparse.sql.Statement()
        else:
            # Every statement found in `sql`; the lazy views use only the first.
            self.all_statements = list(parsed_statements)
            self.parsed = parsed_statements[0]

    # region Properties
    @property
    def tokens(self) -> list[tuple[sqlparse.tokens._TokenType, str]]:
        '''Returns a flattened list of tokens as (ttype, value) tuples, excluding whitespace and newlines.'''
        # Compare against None, like the other lazy properties, so that an
        # empty token list is cached instead of being recomputed on every access.
        if self._tokens is None:
            self._tokens = self._flatten()
        return self._tokens

    @property
    def functions(self) -> list[tuple[sqlparse.sql.Function, str]]:
        '''Returns a list of (function, clause) tuples found in the SQL query.'''
        if self._functions is None:
            self._functions = extractors.extract_functions(self.parsed.tokens)
        return self._functions

    @property
    def comparisons(self) -> list[tuple[sqlparse.sql.Comparison, str]]:
        '''Returns a list of (comparison, clause) tuples found in the SQL query.'''
        if self._comparisons is None:
            self._comparisons = extractors.extract_comparisons(self.parsed.tokens)
        return self._comparisons

    # endregion

    def _flatten(self) -> list[tuple[sqlparse.tokens._TokenType, str]]:
        '''Flattens the parsed SQL statement into a list of (ttype, value) tuples. Ignores whitespace and newlines.'''
        if not self.parsed:
            return []

        # Flatten tokens into (ttype, value)
        return [
            (tok.ttype, tok.value) for tok in self.parsed.flatten()
            if tok.ttype not in (Whitespace, Newline)
        ]

    def print_tree(self) -> None:
        '''Pretty-print the sqlparse parse tree of every parsed statement (via sqlparse's `_pprint_tree`).'''
        for stmt in self.all_statements:
            stmt._pprint_tree()
70
+
@@ -0,0 +1,21 @@
1
+ from .base import get_type
2
+ from . import primitives, functions, queries, unary_ops, binary_ops, predicates
3
+ from ...catalog import Catalog
4
+ from sqlglot import exp
5
+ from sqlglot.optimizer.annotate_types import annotate_types
6
+ from sqlglot.optimizer.qualify import qualify
7
+
8
+ __all__ = ["get_type"]
9
+
10
def rewrite_expression(expression: exp.Expression, catalog: Catalog, search_path: str = 'public') -> exp.Expression:
    '''
    Qualify `expression` against the catalog, then annotate its nodes with types.

    The catalog is converted to a sqlglot schema. `search_path` is passed to
    sqlglot's `qualify` as the default database (`db`), and column
    validation is turned off (`validate_qualify_columns=False`). The
    qualified tree is then passed through `annotate_types` with the same
    schema, and the annotated expression is returned.
    '''
    sqlglot_schema = catalog.to_sqlglot_schema()
    qualified = qualify(
        expression,
        schema=sqlglot_schema,
        db=search_path,
        validate_qualify_columns=False,
    )
    return annotate_types(qualified, sqlglot_schema)
18
+
19
# Must be called on an expression whose nodes are already type-annotated
# (for instance, the result of rewrite_expression).
def collect_errors(expression: exp.Expression, catalog: Catalog, search_path: str = 'public') -> list[tuple[str, str, str | None]]:
    '''Typecheck `expression` with get_type and return the messages it collected.'''
    checked = get_type(expression, catalog, search_path)
    return checked.messages
@@ -0,0 +1,9 @@
1
+ import sqlglot.expressions as exp
2
+ from .types import AtomicType, ResultType
3
+ from ...catalog import Catalog
4
+ from functools import singledispatch
5
+
6
@singledispatch
def get_type(expression: exp.Expression, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Return the type of a SQL expression.

    This is the generic fallback of a functools.singledispatch function.
    Handlers for specific sqlglot node classes are added with
    `@get_type.register`; nodes without a handler get a default AtomicType().
    '''
    fallback = AtomicType()
    return fallback
@@ -0,0 +1,57 @@
1
+ from .base import get_type
2
+ from ...catalog import Catalog
3
+ from sqlglot import exp
4
+ from .types import ResultType, AtomicType
5
+ from sqlglot.expressions import DataType
6
+ from .util import is_number, to_number, to_date, error_message
7
+
8
@get_type.register
def _(expression: exp.Binary, catalog: Catalog, search_path: str) -> ResultType:
    '''
    Typecheck a binary operator node.

    Predicates (comparisons) are delegated to typecheck_comparisons. Every
    other binary operator expects numeric operands. If the two operand types
    differ, an error is recorded for each operand that is not UNKNOWN, not
    NULL and not convertible to a number. If the types are equal and
    non-numeric (and not UNKNOWN), one error is recorded for the left
    operand, unless both sides are NULL. Both operands' messages are
    carried over into a new list.
    '''
    lhs = get_type(expression.this, catalog, search_path)
    rhs = get_type(expression.expression, catalog, search_path)
    messages = lhs.messages + rhs.messages

    if isinstance(expression, exp.Predicate):
        return typecheck_comparisons(lhs, rhs, expression, messages)

    unknown = DataType.Type.UNKNOWN
    null = DataType.Type.NULL

    if lhs != rhs:
        for operand in (lhs, rhs):
            if operand.data_type != unknown and not to_number(operand) and operand.data_type != null:
                messages.append(error_message(expression, operand, "numeric"))
    elif (unknown != lhs.data_type
          and not is_number(lhs.data_type)
          and not is_number(rhs.data_type)
          and (lhs.data_type != null or rhs.data_type != null)):
        messages.append(error_message(expression, lhs, "numeric"))

    return AtomicType(
        data_type=expression.type.this,
        nullable=lhs.nullable or rhs.nullable,
        constant=lhs.constant and rhs.constant,
        messages=messages,
    )
33
+
34
# Typechecking for comparison predicates (e.g. =, <, >, ...).
def typecheck_comparisons(left_type: ResultType, right_type: ResultType, expression: exp.Binary, old_messages: list) -> ResultType:
    '''
    Typecheck a comparison whose operand types are already known.

    `old_messages` is extended in place with any new errors and placed in
    the result. If either operand type is UNKNOWN, nothing is checked.
    Otherwise two rules apply:
    - BOOLEAN operands support only = and <>.
    - Differing, non-NULL operand types are an error unless both convert to
      numbers or both convert to dates.
    The result is built with nullable=False and constant=True (except on
    the UNKNOWN path, which uses AtomicType's defaults for both).
    '''
    left_dt = left_type.data_type
    right_dt = right_type.data_type

    if DataType.Type.UNKNOWN in (left_dt, right_dt):
        return AtomicType(data_type=expression.type.this, messages=old_messages)

    if DataType.Type.BOOLEAN == left_dt == right_dt and not isinstance(expression, (exp.EQ, exp.NEQ)):
        old_messages.append(error_message(expression, left_type, "boolean"))

    null = DataType.Type.NULL
    if left_type != right_type and left_dt != null and right_dt != null:
        implicitly_castable = (
            (to_number(left_type) and to_number(right_type))
            or (to_date(left_type) and to_date(right_type))
        )
        if not implicitly_castable:
            old_messages.append(error_message(expression, left_type.data_type_str + " & " + right_type.data_type_str))

    # Always returns boolean
    return AtomicType(data_type=expression.type.this, nullable=False, constant=True, messages=old_messages)
@@ -0,0 +1,81 @@
1
+ from .base import get_type
2
+ from ...catalog import Catalog
3
+ from sqlglot import exp
4
+ from .types import ResultType, AtomicType
5
+ from sqlglot.expressions import DataType
6
+ from .util import is_number, error_message
7
+
8
@get_type.register
def _(expression: exp.Count, catalog: Catalog, search_path: str) -> ResultType:
    '''COUNT(...): reported as non-nullable and constant; only the argument's messages are propagated.'''
    argument = get_type(expression.this, catalog, search_path)
    return AtomicType(
        data_type=expression.type.this,
        nullable=False,
        constant=True,
        messages=argument.messages,
    )
13
+
14
@get_type.register
def _(expression: exp.Avg, catalog: Catalog, search_path: str) -> ResultType:
    '''
    AVG(x): the argument must be numeric (or of UNKNOWN type).

    The result is reported as nullable (AVG over no rows is NULL) and
    constant. The argument's messages are copied into a new list, so the
    argument's own ResultType is never modified.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy rather than alias: appending to inner_type.messages would also
    # change the argument's ResultType (NOTE(review): and any other object
    # sharing that list).
    old_messages = list(inner_type.messages)

    if inner_type.data_type != DataType.Type.UNKNOWN and not is_number(inner_type.data_type):
        old_messages.append(error_message(expression, inner_type, "NUMERIC"))

    return AtomicType(data_type=expression.type.this, nullable=True, constant=True, messages=old_messages)
24
+
25
@get_type.register
def _(expression: exp.Sum, catalog: Catalog, search_path: str) -> ResultType:
    '''
    SUM(x): the argument must be numeric (or of UNKNOWN type).

    The result is reported as nullable (SUM over no rows is NULL) and
    constant. The argument's messages are copied into a new list, so the
    argument's own ResultType is never modified.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy rather than alias: appending to inner_type.messages would also
    # change the argument's ResultType (NOTE(review): and any other object
    # sharing that list).
    old_messages = list(inner_type.messages)

    if inner_type.data_type != DataType.Type.UNKNOWN and not is_number(inner_type.data_type):
        old_messages.append(error_message(expression, inner_type, "NUMERIC"))

    return AtomicType(data_type=expression.type.this, nullable=True, constant=True, messages=old_messages)
35
+
36
@get_type.register
def _(expression: exp.Min, catalog: Catalog, search_path: str) -> ResultType:
    '''
    MIN(x): the result has the argument's data type; a BOOLEAN argument is reported as an error.

    Like AVG and SUM, the result is reported as nullable, because MIN over
    no rows is NULL whatever the argument's own nullability. The argument's
    messages are copied, so its ResultType is never modified.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy rather than alias the argument's message list.
    old_messages = list(inner_type.messages)

    # UNKNOWN never equals BOOLEAN, so no separate UNKNOWN check is needed.
    if inner_type.data_type == DataType.Type.BOOLEAN:
        old_messages.append(error_message(expression, inner_type))

    return AtomicType(data_type=inner_type.data_type, nullable=True, constant=True, messages=old_messages)
46
+
47
@get_type.register
def _(expression: exp.Max, catalog: Catalog, search_path: str) -> ResultType:
    '''
    MAX(x): the result has the argument's data type; a BOOLEAN argument is reported as an error.

    Like AVG and SUM, the result is reported as nullable, because MAX over
    no rows is NULL whatever the argument's own nullability. The argument's
    messages are copied, so its ResultType is never modified.
    '''
    inner_type = get_type(expression.this, catalog, search_path)

    # Copy rather than alias the argument's message list.
    old_messages = list(inner_type.messages)

    # UNKNOWN never equals BOOLEAN, so no separate UNKNOWN check is needed.
    if inner_type.data_type == DataType.Type.BOOLEAN:
        old_messages.append(error_message(expression, inner_type))

    return AtomicType(data_type=inner_type.data_type, nullable=True, constant=True, messages=old_messages)
57
+
58
@get_type.register
def _(expression: exp.Concat, catalog: Catalog, search_path: str) -> ResultType:
    '''
    CONCAT(...): typecheck every argument and combine their properties.

    - All argument messages are gathered into a new list.
    - With no arguments, an "Empty arguments" error is recorded. Since all()
      of an empty sequence is True, the NULL result below is then returned.
    - If every argument is NULL, the result is a constant NULL type.
    - Otherwise the result uses the annotated node type. It is nullable if
      any argument is nullable, and constant if all arguments are constant.
    '''
    collected_messages = []
    argument_types = []

    for argument in expression.expressions:
        argument_type = get_type(argument, catalog, search_path)
        if argument_type.messages:
            collected_messages.extend(argument_type.messages)
        argument_types.append(argument_type)

    if not argument_types:
        collected_messages.append(error_message(expression, "Empty arguments"))

    if all(t.data_type == DataType.Type.NULL for t in argument_types):
        return AtomicType(data_type=DataType.Type.NULL, constant=True, messages=collected_messages)

    is_constant = all(t.constant for t in argument_types)
    is_nullable = any(t.nullable for t in argument_types)

    return AtomicType(
        data_type=expression.type.this,
        nullable=is_nullable,
        constant=is_constant,
        messages=collected_messages,
    )