altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,202 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+ from sqlglot.dialects.dialect import rename_func, unit_to_var, timestampdiff_sql, build_date_delta
7
+ from sqlglot.dialects.hive import _build_with_ignore_nulls
8
+ from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
9
+ from sqlglot.helper import ensure_list, seq_get
10
+ from sqlglot.transforms import (
11
+ ctas_with_tmp_tables_to_create_tmp_view,
12
+ remove_unique_constraints,
13
+ preprocess,
14
+ move_partitioned_by_to_schema_columns,
15
+ )
16
+
17
+
18
def _build_datediff(args: t.List) -> exp.Expression:
    """Build an ``exp.DateDiff`` node from DATEDIFF/DATE_DIFF call arguments.

    Although Spark docs don't mention the "unit" argument, Spark3 added support for
    it at some point. Databricks also supports this variant (see below).

    For example, in spark-sql (v3.3.1):
    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4

    See also:
    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
    """
    end_arg = seq_get(args, 0)
    start_arg = seq_get(args, 1)

    if len(args) != 3:
        unit = None
    else:
        # Three-argument variant: the first argument is actually the unit
        unit = exp.var(t.cast(exp.Expression, end_arg).name)
        end_arg = args[2]

    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=end_arg),
        expression=exp.TsOrDsToDate(this=start_arg),
        unit=unit,
    )
42
+
43
+
44
def _build_dateadd(args: t.List) -> exp.Expression:
    """Build the AST node for DATE_ADD / DATEADD / TIMESTAMPADD.

    The two-argument form adds whole days to a date; the three-argument form is
    a unit-based timestamp addition.
    - https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
    - https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
    """
    second_arg = seq_get(args, 1)

    if len(args) != 2:
        # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
        return exp.TimestampAdd(
            this=seq_get(args, 2), expression=second_arg, unit=seq_get(args, 0)
        )

    # DATE_ADD(startDate, numDays INTEGER)
    return exp.TsOrDsAdd(
        this=seq_get(args, 0), expression=second_arg, unit=exp.Literal.string("DAY")
    )
57
+
58
+
59
def _normalize_partition(e: exp.Expression) -> exp.Expression:
    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
    # Bare strings and literal values both become plain identifiers; anything
    # else (an already-built expression) passes through untouched.
    if isinstance(e, exp.Literal):
        return exp.to_identifier(e.name)
    return exp.to_identifier(e) if isinstance(e, str) else e
66
+
67
+
68
def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
    """Render a TsOrDsAdd / TimestampAdd node as Spark's DATE_ADD."""
    is_ts_or_ds = isinstance(expression, exp.TsOrDsAdd)

    if not expression.unit or (is_ts_or_ds and expression.text("unit").upper() == "DAY"):
        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
        return self.func("DATE_ADD", expression.this, expression.expression)

    sql = self.func(
        "DATE_ADD",
        unit_to_var(expression),
        expression.expression,
        expression.this,
    )

    if is_ts_or_ds:
        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
        # in other dialects, so cast back to the expected return type when needed.
        return_type = expression.return_type
        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
            sql = f"CAST({sql} AS {return_type})"

    return sql
90
+
91
+
92
class Spark(Spark2):
    """Spark dialect extending Spark2 with Spark3/Databricks-era behavior.

    Wires the module's builders (_build_dateadd, _build_datediff, ...) into the
    parser and the renderers (_dateadd_sql, ...) into the generator.
    """

    SUPPORTS_ORDER_BY_ALL = True

    class Tokenizer(Spark2.Tokenizer):
        STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False

        # Raw strings: every configured quote, prefixed with r/R (e.g. r'...', R"...")
        RAW_STRINGS = [
            (prefix + q, q)
            for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
            for prefix in ("r", "R")
        ]

    class Parser(Spark2.Parser):
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,
            "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
            "DATE_ADD": _build_dateadd,
            "DATEADD": _build_dateadd,
            "TIMESTAMPADD": _build_dateadd,
            "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
            "DATEDIFF": _build_datediff,
            "DATE_DIFF": _build_datediff,
            "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
            "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
            # TRY_ELEMENT_AT maps to a 1-based, null-safe bracket access
            "TRY_ELEMENT_AT": lambda args: exp.Bracket(
                this=seq_get(args, 0),
                expressions=ensure_list(seq_get(args, 1)),
                offset=1,
                safe=True,
            ),
        }

        def _parse_generated_as_identity(
            self,
        ) -> (
            exp.GeneratedAsIdentityColumnConstraint
            | exp.ComputedColumnConstraint
            | exp.GeneratedAsRowColumnConstraint
        ):
            # A GENERATED ... AS (<expr>) constraint with an expression is a
            # computed column rather than an identity column.
            this = super()._parse_generated_as_identity()
            if this.expression:
                return self.expression(exp.ComputedColumnConstraint, this=this.expression)
            return this

    class Generator(Spark2.Generator):
        SUPPORTS_TO_NUMBER = True
        PAD_FILL_PATTERN_IS_REQUIRED = False
        SUPPORTS_CONVERT_TIMEZONE = True
        SUPPORTS_MEDIAN = True
        SUPPORTS_UNIX_SECONDS = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
            exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
            exp.DataType.Type.UUID: "STRING",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
        }

        TRANSFORMS = {
            **Spark2.Generator.TRANSFORMS,
            exp.ArrayConstructCompact: lambda self, e: self.func(
                "ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
            ),
            exp.Create: preprocess(
                [
                    remove_unique_constraints,
                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
                        e, temporary_storage_provider
                    ),
                    move_partitioned_by_to_schema_columns,
                ]
            ),
            exp.EndsWith: rename_func("ENDSWITH"),
            exp.PartitionedByProperty: lambda self,
            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.TsOrDsAdd: _dateadd_sql,
            exp.TimestampAdd: _dateadd_sql,
            exp.DatetimeDiff: timestampdiff_sql,
            exp.TimestampDiff: timestampdiff_sql,
            exp.TryCast: lambda self, e: (
                self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
            ),
        }
        # Drop inherited handlers so generation falls through to the *_sql
        # methods defined below (AnyValue, DateDiff) or the base behavior (Group).
        TRANSFORMS.pop(exp.AnyValue)
        TRANSFORMS.pop(exp.DateDiff)
        TRANSFORMS.pop(exp.Group)

        def bracket_sql(self, expression: exp.Bracket) -> str:
            # A "safe" bracket access round-trips back to TRY_ELEMENT_AT
            if expression.args.get("safe"):
                key = seq_get(self.bracket_offset_expressions(expression, index_offset=1), 0)
                return self.func("TRY_ELEMENT_AT", expression.this, key)

            return super().bracket_sql(expression)

        def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
            return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"

        def anyvalue_sql(self, expression: exp.AnyValue) -> str:
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            # Mirrors _build_datediff: 3-arg form when a unit is present,
            # 2-arg (end, start) form otherwise.
            end = self.sql(expression, "this")
            start = self.sql(expression, "expression")

            if expression.unit:
                return self.func("DATEDIFF", unit_to_var(expression), start, end)

            return self.func("DATEDIFF", end, start)
@@ -0,0 +1,349 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import exp, transforms
6
+ from sqlglot.dialects.dialect import (
7
+ binary_from_function,
8
+ build_formatted_time,
9
+ is_parse_json,
10
+ pivot_column_names,
11
+ rename_func,
12
+ trim_sql,
13
+ unit_to_str,
14
+ )
15
+ from sqlglot.dialects.hive import Hive
16
+ from sqlglot.helper import seq_get, ensure_list
17
+ from sqlglot.tokens import TokenType
18
+ from sqlglot.transforms import (
19
+ preprocess,
20
+ remove_unique_constraints,
21
+ ctas_with_tmp_tables_to_create_tmp_view,
22
+ move_schema_columns_to_partitioned_by,
23
+ )
24
+
25
+ if t.TYPE_CHECKING:
26
+ from sqlglot._typing import E
27
+
28
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
29
+
30
+
31
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
    """Render exp.Map as MAP_FROM_ARRAYS(keys, values), or bare MAP() when empty."""
    keys = expression.args.get("keys")
    values = expression.args.get("values")

    if keys and values:
        return self.func("MAP_FROM_ARRAYS", keys, values)

    return self.func("MAP")
39
+
40
+
41
def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
    """Return a builder that parses a one-argument function call as CAST(arg AS to_type)."""

    def _builder(args: t.List) -> exp.Expression:
        return exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))

    return _builder
43
+
44
+
45
def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
    """Render StrToDate as TO_DATE, omitting the format when it is Hive's default."""
    fmt = self.format_time(expression)
    if fmt == Hive.DATE_FORMAT:
        # TO_DATE assumes the default date format, so the argument is redundant
        return self.func("TO_DATE", expression.this)
    return self.func("TO_DATE", expression.this, fmt)
50
+
51
+
52
def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
    """Render UnixToTime via the TIMESTAMP_* helper matching the epoch scale."""
    scale = expression.args.get("scale")
    timestamp = expression.this

    if scale is None:
        # No scale information: fall back to from_unixtime wrapped in a cast
        return self.sql(
            exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP)
        )

    if scale == exp.UnixToTime.SECONDS:
        return self.func("TIMESTAMP_SECONDS", timestamp)
    if scale == exp.UnixToTime.MILLIS:
        return self.func("TIMESTAMP_MILLIS", timestamp)
    if scale == exp.UnixToTime.MICROS:
        return self.func("TIMESTAMP_MICROS", timestamp)

    # Arbitrary scale: normalize to seconds by dividing by 10^scale
    seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))
    return self.func("TIMESTAMP_SECONDS", seconds)
67
+
68
+
69
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
    pivoted source in a subquery with the same alias to preserve the query's semantics.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
    """
    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
        pivot = expression.this.args["pivots"][0]
        if pivot.alias:
            # Detach the alias from the pivot, then wrap the pivoted source in a
            # SELECT * subquery carrying that alias so references still resolve.
            alias = pivot.args["alias"].pop()
            return exp.From(
                this=expression.this.replace(
                    exp.select("*")
                    .from_(expression.this.copy(), copy=False)
                    .subquery(alias=alias, copy=False)
                )
            )

    return expression
93
+
94
+
95
def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
    so we need to unqualify it.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
    """
    if isinstance(expression, exp.Pivot):
        # Strip table qualifiers from every pivoted field; literals are untouched
        expression.set(
            "fields", [transforms.unqualify_columns(field) for field in expression.fields]
        )

    return expression
112
+
113
+
114
def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
    """Attach a parquet file-format property to the expression, in place.

    Spark2, Spark and Databricks require a storage provider for temporary tables.
    """
    file_format = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
    expression.args["properties"].append("expressions", file_format)
    return expression
119
+
120
+
121
def _annotate_by_similar_args(
    self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type
) -> E:
    """
    Infers the type of the expression according to the following rules:
    - If all args are of the same type OR any arg is of target_type, the expr is inferred as such
    - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN
    """
    self._annotate_args(expression)

    # Collect every non-empty sub-expression stored under the given arg keys
    expressions: t.List[exp.Expression] = []
    for arg in args:
        arg_expr = expression.args.get(arg)
        expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

    last_datatype = None

    has_unknown = False
    for expr in expressions:
        if expr.is_type(exp.DataType.Type.UNKNOWN):
            has_unknown = True
        elif expr.is_type(target_type):
            # A single target_type arg decides the result and overrides any
            # UNKNOWNs seen so far, so we can stop scanning.
            has_unknown = False
            last_datatype = target_type
            break
        else:
            last_datatype = expr.type

    self._set_type(expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype)
    return expression
151
+
152
+
153
class Spark2(Hive):
    """Spark 2 dialect, built on Hive with Spark-specific parsing and generation."""

    ANNOTATORS = {
        **Hive.ANNOTATORS,
        exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Concat: lambda self, e: _annotate_by_similar_args(
            self, e, "expressions", target_type=exp.DataType.Type.TEXT
        ),
        exp.Pad: lambda self, e: _annotate_by_similar_args(
            self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT
        ),
    }

    class Tokenizer(Hive.Tokenizer):
        HEX_STRINGS = [("X'", "'"), ("x'", "'")]

        KEYWORDS = {
            **Hive.Tokenizer.KEYWORDS,
            # Spark's TIMESTAMP keyword is tokenized as timezone-aware
            "TIMESTAMP": TokenType.TIMESTAMPTZ,
        }

    class Parser(Hive.Parser):
        TRIM_PATTERN_FIRST = True

        FUNCTIONS = {
            **Hive.Parser.FUNCTIONS,
            "AGGREGATE": exp.Reduce.from_arg_list,
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "BOOLEAN": _build_as_cast("boolean"),
            "DATE": _build_as_cast("date"),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
            ),
            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DOUBLE": _build_as_cast("double"),
            "FLOAT": _build_as_cast("float"),
            "FROM_UTC_TIMESTAMP": lambda args, dialect: exp.AtTimeZone(
                this=exp.cast(
                    seq_get(args, 0) or exp.Var(this=""),
                    exp.DataType.Type.TIMESTAMP,
                    dialect=dialect,
                ),
                zone=seq_get(args, 1),
            ),
            "INT": _build_as_cast("int"),
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "RLIKE": exp.RegexpLike.from_arg_list,
            "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
            "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
            "STRING": _build_as_cast("string"),
            "TIMESTAMP": _build_as_cast("timestamp"),
            # One-arg TO_TIMESTAMP is a plain cast; two-arg carries a format string
            "TO_TIMESTAMP": lambda args: (
                _build_as_cast("timestamp")(args)
                if len(args) == 1
                else build_formatted_time(exp.StrToTime, "spark")(args)
            ),
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
            "TO_UTC_TIMESTAMP": lambda args, dialect: exp.FromTimeZone(
                this=exp.cast(
                    seq_get(args, 0) or exp.Var(this=""),
                    exp.DataType.Type.TIMESTAMP,
                    dialect=dialect,
                ),
                zone=seq_get(args, 1),
            ),
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
        }

        # Spark join hints parsed as function-style calls, e.g. /*+ BROADCAST(t) */
        FUNCTION_PARSERS = {
            **Hive.Parser.FUNCTION_PARSERS,
            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
            "MERGE": lambda self: self._parse_join_hint("MERGE"),
            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
        }

        def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
            # Returns None (falsy match result) when the next tokens aren't DROP COLUMNS
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop, this=self._parse_schema(), kind="COLUMNS"
            )

        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
            # A single aggregation needs no explicit pivot column names
            if len(aggregations) == 1:
                return []
            return pivot_column_names(aggregations, dialect="spark")

    class Generator(Hive.Generator):
        QUERY_HINTS = True
        NVL2_SUPPORTED = True
        CAN_IMPLEMENT_ARRAY_ANY = True

        PROPERTIES_LOCATION = {
            **Hive.Generator.PROPERTIES_LOCATION,
            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
        }

        TRANSFORMS = {
            **Hive.Generator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySum: lambda self,
            e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.ArrayToString: rename_func("ARRAY_JOIN"),
            exp.AtTimeZone: lambda self, e: self.func(
                "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
            ),
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.Create: preprocess(
                [
                    remove_unique_constraints,
                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
                        e, temporary_storage_provider
                    ),
                    move_schema_columns_to_partitioned_by,
                ]
            ),
            exp.DateFromParts: rename_func("MAKE_DATE"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            # (DAY_OF_WEEK(datetime) % 7) + 1 is equivalent to DAYOFWEEK_ISO(datetime)
            exp.DayOfWeekIso: lambda self, e: f"(({self.func('DAYOFWEEK', e.this)} % 7) + 1)",
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.From: transforms.preprocess([_unalias_pivot]),
            exp.FromTimeZone: lambda self, e: self.func(
                "TO_UTC_TIMESTAMP", e.this, e.args.get("zone")
            ),
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
            exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
            exp.Reduce: rename_func("AGGREGATE"),
            exp.RegexpReplace: lambda self, e: self.func(
                "REGEXP_REPLACE",
                e.this,
                e.expression,
                e.args["replacement"],
                e.args.get("position"),
            ),
            exp.Select: transforms.preprocess(
                [
                    transforms.eliminate_qualify,
                    transforms.eliminate_distinct_on,
                    transforms.unnest_to_explode,
                    transforms.any_to_exists,
                ]
            ),
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
            exp.Trim: trim_sql,
            exp.UnixToTime: _unix_to_time_sql,
            exp.VariancePop: rename_func("VAR_POP"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
            exp.WithinGroup: transforms.preprocess(
                [transforms.remove_within_group_for_percentiles]
            ),
        }
        # Drop Hive handlers that don't apply; generation falls back to defaults
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)
        TRANSFORMS.pop(exp.Left)
        TRANSFORMS.pop(exp.MonthsBetween)
        TRANSFORMS.pop(exp.Right)

        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False

        def struct_sql(self, expression: exp.Struct) -> str:
            # Bypass Hive's struct rendering in favor of the base generator's
            from sqlglot.generator import Generator

            return Generator.struct_sql(self, expression)

        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
            arg = expression.this
            is_json_extract = isinstance(
                arg, (exp.JSONExtract, exp.JSONExtractScalar)
            ) and not arg.args.get("variant_extract")

            # We can't use a non-nested type (eg. STRING) as a schema
            if expression.to.args.get("nested") and (is_parse_json(arg) or is_json_extract):
                schema = f"'{self.sql(expression, 'to')}'"
                return self.func("FROM_JSON", arg if is_json_extract else arg.this, schema)

            if is_parse_json(expression):
                return self.func("TO_JSON", arg)

            # Skip Hive.Generator in the MRO and use the base cast rendering
            return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)