altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,589 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import typing as t
5
+
6
+ from sqlglot import exp
7
+ from sqlglot.helper import (
8
+ ensure_list,
9
+ is_date_unit,
10
+ is_iso_date,
11
+ is_iso_datetime,
12
+ seq_get,
13
+ )
14
+ from sqlglot.optimizer.scope import Scope, traverse_scope
15
+ from sqlglot.schema import Schema, ensure_schema
16
+ from sqlglot.dialects.dialect import Dialect
17
+
18
+ if t.TYPE_CHECKING:
19
+ from sqlglot._typing import B, E
20
+
21
+ BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
22
+ BinaryCoercions = t.Dict[
23
+ t.Tuple[exp.DataType.Type, exp.DataType.Type],
24
+ BinaryCoercionFunc,
25
+ ]
26
+
27
+ from sqlglot.dialects.dialect import DialectType, AnnotatorsType
28
+
29
+
30
+ def annotate_types(
31
+ expression: E,
32
+ schema: t.Optional[t.Dict | Schema] = None,
33
+ annotators: t.Optional[AnnotatorsType] = None,
34
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
35
+ dialect: DialectType = None,
36
+ ) -> E:
37
+ """
38
+ Infers the types of an expression, annotating its AST accordingly.
39
+
40
+ Example:
41
+ >>> import sqlglot
42
+ >>> schema = {"y": {"cola": "SMALLINT"}}
43
+ >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
44
+ >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
45
+ >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
46
+ <Type.DOUBLE: 'DOUBLE'>
47
+
48
+ Args:
49
+ expression: Expression to annotate.
50
+ schema: Database schema.
51
+ annotators: Maps expression type to corresponding annotation function.
52
+ coerces_to: Maps expression type to set of types that it can be coerced into.
53
+
54
+ Returns:
55
+ The expression annotated with types.
56
+ """
57
+
58
+ schema = ensure_schema(schema, dialect=dialect)
59
+
60
+ return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
61
+
62
+
63
+ def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
64
+ date_text = l.name
65
+ is_iso_date_ = is_iso_date(date_text)
66
+
67
+ if is_iso_date_ and is_date_unit(unit):
68
+ return exp.DataType.Type.DATE
69
+
70
+ # An ISO date is also an ISO datetime, but not vice versa
71
+ if is_iso_date_ or is_iso_datetime(date_text):
72
+ return exp.DataType.Type.DATETIME
73
+
74
+ return exp.DataType.Type.UNKNOWN
75
+
76
+
77
+ def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
78
+ if not is_date_unit(unit):
79
+ return exp.DataType.Type.DATETIME
80
+ return l.type.this if l.type else exp.DataType.Type.UNKNOWN
81
+
82
+
83
+ def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
84
+ @functools.wraps(func)
85
+ def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
86
+ return func(r, l)
87
+
88
+ return _swapped
89
+
90
+
91
+ def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
92
+ return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
93
+
94
+
95
+ class _TypeAnnotator(type):
96
+ def __new__(cls, clsname, bases, attrs):
97
+ klass = super().__new__(cls, clsname, bases, attrs)
98
+
99
+ # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
100
+ # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
101
+ text_precedence = (
102
+ exp.DataType.Type.TEXT,
103
+ exp.DataType.Type.NVARCHAR,
104
+ exp.DataType.Type.VARCHAR,
105
+ exp.DataType.Type.NCHAR,
106
+ exp.DataType.Type.CHAR,
107
+ )
108
+ numeric_precedence = (
109
+ exp.DataType.Type.DOUBLE,
110
+ exp.DataType.Type.FLOAT,
111
+ exp.DataType.Type.DECIMAL,
112
+ exp.DataType.Type.BIGINT,
113
+ exp.DataType.Type.INT,
114
+ exp.DataType.Type.SMALLINT,
115
+ exp.DataType.Type.TINYINT,
116
+ )
117
+ timelike_precedence = (
118
+ exp.DataType.Type.TIMESTAMPLTZ,
119
+ exp.DataType.Type.TIMESTAMPTZ,
120
+ exp.DataType.Type.TIMESTAMP,
121
+ exp.DataType.Type.DATETIME,
122
+ exp.DataType.Type.DATE,
123
+ )
124
+
125
+ for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
126
+ coerces_to = set()
127
+ for data_type in type_precedence:
128
+ klass.COERCES_TO[data_type] = coerces_to.copy()
129
+ coerces_to |= {data_type}
130
+
131
+ # NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
132
+ klass.COERCES_TO[exp.DataType.Type.NULL] = {
133
+ *text_precedence,
134
+ *numeric_precedence,
135
+ *timelike_precedence,
136
+ }
137
+
138
+ return klass
139
+
140
+
141
+ class TypeAnnotator(metaclass=_TypeAnnotator):
142
+ NESTED_TYPES = {
143
+ exp.DataType.Type.ARRAY,
144
+ }
145
+
146
+ # Specifies what types a given type can be coerced into (autofilled)
147
+ COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
148
+
149
+ # Coercion functions for binary operations.
150
+ # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
151
+ BINARY_COERCIONS: BinaryCoercions = {
152
+ **swap_all(
153
+ {
154
+ (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
155
+ l, r.args.get("unit")
156
+ )
157
+ for t in exp.DataType.TEXT_TYPES
158
+ }
159
+ ),
160
+ **swap_all(
161
+ {
162
+ # text + numeric will yield the numeric type to match most dialects' semantics
163
+ (text, numeric): lambda l, r: t.cast(
164
+ exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
165
+ )
166
+ for text in exp.DataType.TEXT_TYPES
167
+ for numeric in exp.DataType.NUMERIC_TYPES
168
+ }
169
+ ),
170
+ **swap_all(
171
+ {
172
+ (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
173
+ l, r.args.get("unit")
174
+ ),
175
+ }
176
+ ),
177
+ }
178
+
179
+ def __init__(
180
+ self,
181
+ schema: Schema,
182
+ annotators: t.Optional[AnnotatorsType] = None,
183
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
184
+ binary_coercions: t.Optional[BinaryCoercions] = None,
185
+ ) -> None:
186
+ self.schema = schema
187
+ self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS
188
+ self.coerces_to = (
189
+ coerces_to or Dialect.get_or_raise(schema.dialect).COERCES_TO or self.COERCES_TO
190
+ )
191
+ self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
192
+
193
+ # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
194
+ self._visited: t.Set[int] = set()
195
+
196
+ # Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the
197
+ # exp.SetOperation is the expression of a scope source, as selecting from it multiple times
198
+ # would reprocess the entire subtree to coerce the types of its operands' projections
199
+ self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}
200
+
201
+ def _set_type(
202
+ self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
203
+ ) -> None:
204
+ expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
205
+ self._visited.add(id(expression))
206
+
207
+ def annotate(self, expression: E) -> E:
208
+ for scope in traverse_scope(expression):
209
+ self.annotate_scope(scope)
210
+ return self._maybe_annotate(expression) # This takes care of non-traversable expressions
211
+
212
+ def annotate_scope(self, scope: Scope) -> None:
213
+ selects = {}
214
+ for name, source in scope.sources.items():
215
+ if not isinstance(source, Scope):
216
+ continue
217
+
218
+ expression = source.expression
219
+ if isinstance(expression, exp.UDTF):
220
+ values = []
221
+
222
+ if isinstance(expression, exp.Lateral):
223
+ if isinstance(expression.this, exp.Explode):
224
+ values = [expression.this.this]
225
+ elif isinstance(expression, exp.Unnest):
226
+ values = [expression]
227
+ elif not isinstance(expression, exp.TableFromRows):
228
+ values = expression.expressions[0].expressions
229
+
230
+ if not values:
231
+ continue
232
+
233
+ selects[name] = {
234
+ alias: column.type
235
+ for alias, column in zip(expression.alias_column_names, values)
236
+ }
237
+ elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
238
+ expression.right.selects
239
+ ):
240
+ selects[name] = col_types = self._setop_column_types.setdefault(id(expression), {})
241
+
242
+ if not col_types:
243
+ # Process a chain / sub-tree of set operations
244
+ for set_op in expression.walk(
245
+ prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery))
246
+ ):
247
+ if not isinstance(set_op, exp.SetOperation):
248
+ continue
249
+
250
+ if set_op.args.get("by_name"):
251
+ r_type_by_select = {
252
+ s.alias_or_name: s.type for s in set_op.right.selects
253
+ }
254
+ setop_cols = {
255
+ s.alias_or_name: self._maybe_coerce(
256
+ t.cast(exp.DataType, s.type),
257
+ r_type_by_select.get(s.alias_or_name)
258
+ or exp.DataType.Type.UNKNOWN,
259
+ )
260
+ for s in set_op.left.selects
261
+ }
262
+ else:
263
+ setop_cols = {
264
+ ls.alias_or_name: self._maybe_coerce(
265
+ t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
266
+ )
267
+ for ls, rs in zip(set_op.left.selects, set_op.right.selects)
268
+ }
269
+
270
+ # Coerce intermediate results with the previously registered types, if they exist
271
+ for col_name, col_type in setop_cols.items():
272
+ col_types[col_name] = self._maybe_coerce(
273
+ col_type, col_types.get(col_name, exp.DataType.Type.NULL)
274
+ )
275
+
276
+ else:
277
+ selects[name] = {s.alias_or_name: s.type for s in expression.selects}
278
+
279
+ # First annotate the current scope's column references
280
+ for col in scope.columns:
281
+ if not col.table:
282
+ continue
283
+
284
+ source = scope.sources.get(col.table)
285
+ if isinstance(source, exp.Table):
286
+ self._set_type(col, self.schema.get_column_type(source, col))
287
+ elif source:
288
+ if col.table in selects and col.name in selects[col.table]:
289
+ self._set_type(col, selects[col.table][col.name])
290
+ elif isinstance(source.expression, exp.Unnest):
291
+ self._set_type(col, source.expression.type)
292
+
293
+ # Then (possibly) annotate the remaining expressions in the scope
294
+ self._maybe_annotate(scope.expression)
295
+
296
+ def _maybe_annotate(self, expression: E) -> E:
297
+ if id(expression) in self._visited:
298
+ return expression # We've already inferred the expression's type
299
+
300
+ annotator = self.annotators.get(expression.__class__)
301
+
302
+ return (
303
+ annotator(self, expression)
304
+ if annotator
305
+ else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
306
+ )
307
+
308
+ def _annotate_args(self, expression: E) -> E:
309
+ for value in expression.iter_expressions():
310
+ self._maybe_annotate(value)
311
+
312
+ return expression
313
+
314
+ def _maybe_coerce(
315
+ self,
316
+ type1: exp.DataType | exp.DataType.Type,
317
+ type2: exp.DataType | exp.DataType.Type,
318
+ ) -> exp.DataType | exp.DataType.Type:
319
+ """
320
+ Returns type2 if type1 can be coerced into it, otherwise type1.
321
+
322
+ If either type is parameterized (e.g. DECIMAL(18, 2) contains two parameters),
323
+ we assume type1 does not coerce into type2, so we also return it in this case.
324
+ """
325
+ if isinstance(type1, exp.DataType):
326
+ if type1.expressions:
327
+ return type1
328
+ type1_value = type1.this
329
+ else:
330
+ type1_value = type1
331
+
332
+ if isinstance(type2, exp.DataType):
333
+ if type2.expressions:
334
+ return type2
335
+ type2_value = type2.this
336
+ else:
337
+ type2_value = type2
338
+
339
+ # We propagate the UNKNOWN type upwards if found
340
+ if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
341
+ return exp.DataType.Type.UNKNOWN
342
+
343
+ return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
344
+
345
+ def _annotate_binary(self, expression: B) -> B:
346
+ self._annotate_args(expression)
347
+
348
+ left, right = expression.left, expression.right
349
+ left_type, right_type = left.type.this, right.type.this # type: ignore
350
+
351
+ if isinstance(expression, (exp.Connector, exp.Predicate)):
352
+ self._set_type(expression, exp.DataType.Type.BOOLEAN)
353
+ elif (left_type, right_type) in self.binary_coercions:
354
+ self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
355
+ else:
356
+ self._set_type(expression, self._maybe_coerce(left_type, right_type))
357
+
358
+ return expression
359
+
360
+ def _annotate_unary(self, expression: E) -> E:
361
+ self._annotate_args(expression)
362
+
363
+ if isinstance(expression, exp.Not):
364
+ self._set_type(expression, exp.DataType.Type.BOOLEAN)
365
+ else:
366
+ self._set_type(expression, expression.this.type)
367
+
368
+ return expression
369
+
370
+ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
371
+ if expression.is_string:
372
+ self._set_type(expression, exp.DataType.Type.VARCHAR)
373
+ elif expression.is_int:
374
+ self._set_type(expression, exp.DataType.Type.INT)
375
+ else:
376
+ self._set_type(expression, exp.DataType.Type.DOUBLE)
377
+
378
+ return expression
379
+
380
+ def _annotate_with_type(
381
+ self, expression: E, target_type: exp.DataType | exp.DataType.Type
382
+ ) -> E:
383
+ self._set_type(expression, target_type)
384
+ return self._annotate_args(expression)
385
+
386
+ @t.no_type_check
387
+ def _annotate_by_args(
388
+ self,
389
+ expression: E,
390
+ *args: str,
391
+ promote: bool = False,
392
+ array: bool = False,
393
+ ) -> E:
394
+ self._annotate_args(expression)
395
+
396
+ expressions: t.List[exp.Expression] = []
397
+ for arg in args:
398
+ arg_expr = expression.args.get(arg)
399
+ expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
400
+
401
+ last_datatype = None
402
+ for expr in expressions:
403
+ expr_type = expr.type
404
+
405
+ # Stop at the first nested data type found - we don't want to _maybe_coerce nested types
406
+ if expr_type.args.get("nested"):
407
+ last_datatype = expr_type
408
+ break
409
+
410
+ if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
411
+ last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
412
+
413
+ self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
414
+
415
+ if promote:
416
+ if expression.type.this in exp.DataType.INTEGER_TYPES:
417
+ self._set_type(expression, exp.DataType.Type.BIGINT)
418
+ elif expression.type.this in exp.DataType.FLOAT_TYPES:
419
+ self._set_type(expression, exp.DataType.Type.DOUBLE)
420
+
421
+ if array:
422
+ self._set_type(
423
+ expression,
424
+ exp.DataType(
425
+ this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
426
+ ),
427
+ )
428
+
429
+ return expression
430
+
431
+ def _annotate_timeunit(
432
+ self, expression: exp.TimeUnit | exp.DateTrunc
433
+ ) -> exp.TimeUnit | exp.DateTrunc:
434
+ self._annotate_args(expression)
435
+
436
+ if expression.this.type.this in exp.DataType.TEXT_TYPES:
437
+ datatype = _coerce_date_literal(expression.this, expression.unit)
438
+ elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
439
+ datatype = _coerce_date(expression.this, expression.unit)
440
+ else:
441
+ datatype = exp.DataType.Type.UNKNOWN
442
+
443
+ self._set_type(expression, datatype)
444
+ return expression
445
+
446
+ def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
447
+ self._annotate_args(expression)
448
+
449
+ bracket_arg = expression.expressions[0]
450
+ this = expression.this
451
+
452
+ if isinstance(bracket_arg, exp.Slice):
453
+ self._set_type(expression, this.type)
454
+ elif this.type.is_type(exp.DataType.Type.ARRAY):
455
+ self._set_type(expression, seq_get(this.type.expressions, 0))
456
+ elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
457
+ index = this.keys.index(bracket_arg)
458
+ value = seq_get(this.values, index)
459
+ self._set_type(expression, value.type if value else None)
460
+ else:
461
+ self._set_type(expression, exp.DataType.Type.UNKNOWN)
462
+
463
+ return expression
464
+
465
+ def _annotate_div(self, expression: exp.Div) -> exp.Div:
466
+ self._annotate_args(expression)
467
+
468
+ left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
469
+
470
+ if (
471
+ expression.args.get("typed")
472
+ and left_type in exp.DataType.INTEGER_TYPES
473
+ and right_type in exp.DataType.INTEGER_TYPES
474
+ ):
475
+ self._set_type(expression, exp.DataType.Type.BIGINT)
476
+ else:
477
+ self._set_type(expression, self._maybe_coerce(left_type, right_type))
478
+ if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
479
+ self._set_type(
480
+ expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
481
+ )
482
+
483
+ return expression
484
+
485
+ def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
486
+ self._annotate_args(expression)
487
+ self._set_type(expression, None)
488
+ this_type = expression.this.type
489
+
490
+ if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
491
+ for e in this_type.expressions:
492
+ if e.name == expression.expression.name:
493
+ self._set_type(expression, e.kind)
494
+ break
495
+
496
+ return expression
497
+
498
+ def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
499
+ self._annotate_args(expression)
500
+ self._set_type(expression, seq_get(expression.this.type.expressions, 0))
501
+ return expression
502
+
503
+ def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
504
+ self._annotate_args(expression)
505
+ child = seq_get(expression.expressions, 0)
506
+
507
+ if child and child.is_type(exp.DataType.Type.ARRAY):
508
+ expr_type = seq_get(child.type.expressions, 0)
509
+ else:
510
+ expr_type = None
511
+
512
+ self._set_type(expression, expr_type)
513
+ return expression
514
+
515
+ def _annotate_struct_value(
516
+ self, expression: exp.Expression
517
+ ) -> t.Optional[exp.DataType] | exp.ColumnDef:
518
+ alias = expression.args.get("alias")
519
+ if alias:
520
+ return exp.ColumnDef(this=alias.copy(), kind=expression.type)
521
+
522
+ # Case: key = value or key := value
523
+ if expression.expression:
524
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
525
+
526
+ return expression.type
527
+
528
+ def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
529
+ self._annotate_args(expression)
530
+ self._set_type(
531
+ expression,
532
+ exp.DataType(
533
+ this=exp.DataType.Type.STRUCT,
534
+ expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
535
+ nested=True,
536
+ ),
537
+ )
538
+ return expression
539
+
540
+ @t.overload
541
+ def _annotate_map(self, expression: exp.Map) -> exp.Map: ...
542
+
543
+ @t.overload
544
+ def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
545
+
546
+ def _annotate_map(self, expression):
547
+ self._annotate_args(expression)
548
+
549
+ keys = expression.args.get("keys")
550
+ values = expression.args.get("values")
551
+
552
+ map_type = exp.DataType(this=exp.DataType.Type.MAP)
553
+ if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
554
+ key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
555
+ value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN
556
+
557
+ if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
558
+ map_type.set("expressions", [key_type, value_type])
559
+ map_type.set("nested", True)
560
+
561
+ self._set_type(expression, map_type)
562
+ return expression
563
+
564
+ def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
565
+ self._annotate_args(expression)
566
+
567
+ map_type = exp.DataType(this=exp.DataType.Type.MAP)
568
+ arg = expression.this
569
+ if arg.is_type(exp.DataType.Type.STRUCT):
570
+ for coldef in arg.type.expressions:
571
+ kind = coldef.kind
572
+ if kind != exp.DataType.Type.UNKNOWN:
573
+ map_type.set("expressions", [exp.DataType.build("varchar"), kind])
574
+ map_type.set("nested", True)
575
+ break
576
+
577
+ self._set_type(expression, map_type)
578
+ return expression
579
+
580
+ def _annotate_extract(self, expression: exp.Extract) -> exp.Extract:
581
+ self._annotate_args(expression)
582
+ part = expression.name
583
+ if part == "TIME":
584
+ self._set_type(expression, exp.DataType.Type.TIME)
585
+ elif part == "DATE":
586
+ self._set_type(expression, exp.DataType.Type.DATE)
587
+ else:
588
+ self._set_type(expression, exp.DataType.Type.INT)
589
+ return expression