sqlglot 27.29.0__py3-none-any.whl → 28.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.1.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/top_level.txt +0 -0
sqlglot/dialects/spark.py CHANGED
@@ -4,7 +4,6 @@ import typing as t
4
4
 
5
5
  from sqlglot import exp
6
6
  from sqlglot.dialects.dialect import (
7
- Version,
8
7
  rename_func,
9
8
  build_like,
10
9
  unit_to_var,
@@ -100,7 +99,7 @@ def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.Timestam
100
99
 
101
100
 
102
101
  def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
103
- if self.dialect.version < Version("4.0.0"):
102
+ if self.dialect.version < (4,):
104
103
  expr = exp.ArrayToString(
105
104
  this=exp.ArrayAgg(this=expression.this),
106
105
  expression=expression.args.get("separator") or exp.Literal.string(""),
@@ -112,6 +111,7 @@ def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
112
111
 
113
112
  class Spark(Spark2):
114
113
  SUPPORTS_ORDER_BY_ALL = True
114
+ SUPPORTS_NULL_TYPE = True
115
115
 
116
116
  class Tokenizer(Spark2.Tokenizer):
117
117
  STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
13
13
  unit_to_str,
14
14
  )
15
15
  from sqlglot.dialects.hive import Hive
16
- from sqlglot.helper import seq_get, ensure_list
16
+ from sqlglot.helper import seq_get
17
17
  from sqlglot.tokens import TokenType
18
18
  from sqlglot.transforms import (
19
19
  preprocess,
@@ -21,11 +21,7 @@ from sqlglot.transforms import (
21
21
  ctas_with_tmp_tables_to_create_tmp_view,
22
22
  move_schema_columns_to_partitioned_by,
23
23
  )
24
-
25
- if t.TYPE_CHECKING:
26
- from sqlglot._typing import E
27
-
28
- from sqlglot.optimizer.annotate_types import TypeAnnotator
24
+ from sqlglot.typing.spark2 import EXPRESSION_METADATA
29
25
 
30
26
 
31
27
  def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -118,51 +114,15 @@ def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
118
114
  return expression
119
115
 
120
116
 
121
- def _annotate_by_similar_args(
122
- self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type
123
- ) -> E:
124
- """
125
- Infers the type of the expression according to the following rules:
126
- - If all args are of the same type OR any arg is of target_type, the expr is inferred as such
127
- - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN
128
- """
129
- self._annotate_args(expression)
130
-
131
- expressions: t.List[exp.Expression] = []
132
- for arg in args:
133
- arg_expr = expression.args.get(arg)
134
- expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
135
-
136
- last_datatype = None
137
-
138
- has_unknown = False
139
- for expr in expressions:
140
- if expr.is_type(exp.DataType.Type.UNKNOWN):
141
- has_unknown = True
142
- elif expr.is_type(target_type):
143
- has_unknown = False
144
- last_datatype = target_type
145
- break
146
- else:
147
- last_datatype = expr.type
148
-
149
- self._set_type(expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype)
150
- return expression
151
-
152
-
153
117
  class Spark2(Hive):
154
118
  ALTER_TABLE_SUPPORTS_CASCADE = False
155
119
 
156
- ANNOTATORS = {
157
- **Hive.ANNOTATORS,
158
- exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
159
- exp.Concat: lambda self, e: _annotate_by_similar_args(
160
- self, e, "expressions", target_type=exp.DataType.Type.TEXT
161
- ),
162
- exp.Pad: lambda self, e: _annotate_by_similar_args(
163
- self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT
164
- ),
165
- }
120
+ EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
121
+
122
+ # https://spark.apache.org/docs/latest/api/sql/index.html#initcap
123
+ # https://docs.databricks.com/aws/en/sql/language-manual/functions/initcap
124
+ # https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L859-L905
125
+ INITCAP_DEFAULT_DELIMITER_CHARS = " "
166
126
 
167
127
  class Tokenizer(Hive.Tokenizer):
168
128
  HEX_STRINGS = [("X'", "'"), ("x'", "'")]
@@ -322,6 +282,9 @@ class Spark2(Hive):
322
282
  transforms.any_to_exists,
323
283
  ]
324
284
  ),
285
+ exp.SHA2Digest: lambda self, e: self.func(
286
+ "SHA2", e.this, e.args.get("length") or exp.Literal.number(256)
287
+ ),
325
288
  exp.StrToDate: _str_to_date,
326
289
  exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
327
290
  exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
18
18
  strposition_sql,
19
19
  )
20
20
  from sqlglot.generator import unsupported_args
21
+ from sqlglot.parser import binary_range_parser
21
22
  from sqlglot.tokens import TokenType
22
23
 
23
24
 
@@ -101,6 +102,8 @@ class SQLite(Dialect):
101
102
  **tokens.Tokenizer.KEYWORDS,
102
103
  "ATTACH": TokenType.ATTACH,
103
104
  "DETACH": TokenType.DETACH,
105
+ "INDEXED BY": TokenType.INDEXED_BY,
106
+ "MATCH": TokenType.MATCH,
104
107
  }
105
108
 
106
109
  KEYWORDS.pop("/*+")
@@ -127,6 +130,12 @@ class SQLite(Dialect):
127
130
  TokenType.DETACH: lambda self: self._parse_attach_detach(is_attach=False),
128
131
  }
129
132
 
133
+ RANGE_PARSERS = {
134
+ **parser.Parser.RANGE_PARSERS,
135
+ # https://www.sqlite.org/lang_expr.html
136
+ TokenType.MATCH: binary_range_parser(exp.Match),
137
+ }
138
+
130
139
  def _parse_unique(self) -> exp.UniqueColumnConstraint:
131
140
  # Do not consume more tokens if UNIQUE is used as a standalone constraint, e.g:
132
141
  # CREATE TABLE foo (bar TEXT UNIQUE REFERENCES baz ...)
@@ -213,13 +213,10 @@ class Teradata(Dialect):
213
213
  def _parse_update(self) -> exp.Update:
214
214
  return self.expression(
215
215
  exp.Update,
216
- **{ # type: ignore
217
- "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
218
- "from": self._parse_from(joins=True),
219
- "expressions": self._match(TokenType.SET)
220
- and self._parse_csv(self._parse_equality),
221
- "where": self._parse_where(),
222
- },
216
+ this=self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
217
+ from_=self._parse_from(joins=True),
218
+ expressions=self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
219
+ where=self._parse_where(),
223
220
  )
224
221
 
225
222
  def _parse_rangen(self):
@@ -387,7 +384,7 @@ class Teradata(Dialect):
387
384
  # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
388
385
  def update_sql(self, expression: exp.Update) -> str:
389
386
  this = self.sql(expression, "this")
390
- from_sql = self.sql(expression, "from")
387
+ from_sql = self.sql(expression, "from_")
391
388
  set_sql = self.expressions(expression, flat=True)
392
389
  where_sql = self.sql(expression, "where")
393
390
  sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
sqlglot/dialects/trino.py CHANGED
@@ -16,6 +16,12 @@ class Trino(Presto):
16
16
  SUPPORTS_USER_DEFINED_TYPES = False
17
17
  LOG_BASE_FIRST = True
18
18
 
19
+ class Tokenizer(Presto.Tokenizer):
20
+ KEYWORDS = {
21
+ **Presto.Tokenizer.KEYWORDS,
22
+ "REFRESH": TokenType.REFRESH,
23
+ }
24
+
19
25
  class Parser(Presto.Parser):
20
26
  FUNCTION_PARSERS = {
21
27
  **Presto.Parser.FUNCTION_PARSERS,
sqlglot/dialects/tsql.py CHANGED
@@ -20,11 +20,13 @@ from sqlglot.dialects.dialect import (
20
20
  strposition_sql,
21
21
  timestrtotime_sql,
22
22
  trim_sql,
23
+ map_date_part,
23
24
  )
24
25
  from sqlglot.helper import seq_get
25
26
  from sqlglot.parser import build_coalesce
26
27
  from sqlglot.time import format_time
27
28
  from sqlglot.tokens import TokenType
29
+ from sqlglot.typing.tsql import EXPRESSION_METADATA
28
30
 
29
31
  if t.TYPE_CHECKING:
30
32
  from sqlglot._typing import E
@@ -56,6 +58,11 @@ DATE_DELTA_INTERVAL = {
56
58
  "d": "day",
57
59
  }
58
60
 
61
+ DATE_PART_UNMAPPING = {
62
+ "WEEKISO": "ISO_WEEK",
63
+ "DAYOFWEEK": "WEEKDAY",
64
+ "TIMEZONE_MINUTE": "TZOFFSET",
65
+ }
59
66
 
60
67
  DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
61
68
 
@@ -200,20 +207,12 @@ def _build_hashbytes(args: t.List) -> exp.Expression:
200
207
  return exp.func("HASHBYTES", *args)
201
208
 
202
209
 
203
- DATEPART_ONLY_FORMATS = {"DW", "WK", "HOUR", "QUARTER", "ISO_WEEK"}
204
-
205
-
206
210
  def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
207
211
  fmt = expression.args["format"]
208
212
 
209
213
  if not isinstance(expression, exp.NumberToStr):
210
214
  if fmt.is_string:
211
215
  mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
212
-
213
- name = (mapped_fmt or "").upper()
214
- if name in DATEPART_ONLY_FORMATS:
215
- return self.func("DATEPART", name, expression.this)
216
-
217
216
  fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
218
217
  else:
219
218
  fmt_sql = self.format_time(expression) or self.sql(fmt)
@@ -243,7 +242,7 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
243
242
 
244
243
 
245
244
  def _build_date_delta(
246
- exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
245
+ exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None, big_int: bool = False
247
246
  ) -> t.Callable[[t.List], E]:
248
247
  def _builder(args: t.List) -> E:
249
248
  unit = seq_get(args, 0)
@@ -259,12 +258,15 @@ def _build_date_delta(
259
258
  else:
260
259
  # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
261
260
  # This is not a problem when generating T-SQL code, it is when transpiling to other dialects.
262
- return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
261
+ return exp_class(
262
+ this=seq_get(args, 2), expression=start_date, unit=unit, big_int=big_int
263
+ )
263
264
 
264
265
  return exp_class(
265
266
  this=exp.TimeStrToTime(this=seq_get(args, 2)),
266
267
  expression=exp.TimeStrToTime(this=start_date),
267
268
  unit=unit,
269
+ big_int=big_int,
268
270
  )
269
271
 
270
272
  return _builder
@@ -412,9 +414,22 @@ class TSQL(Dialect):
412
414
 
413
415
  TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
414
416
 
415
- ANNOTATORS = {
416
- **Dialect.ANNOTATORS,
417
- exp.Radians: lambda self, e: self._annotate_by_args(e, "this"),
417
+ EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
418
+
419
+ DATE_PART_MAPPING = {
420
+ **Dialect.DATE_PART_MAPPING,
421
+ "QQ": "QUARTER",
422
+ "M": "MONTH",
423
+ "Y": "DAYOFYEAR",
424
+ "WW": "WEEK",
425
+ "N": "MINUTE",
426
+ "SS": "SECOND",
427
+ "MCS": "MICROSECOND",
428
+ "TZOFFSET": "TIMEZONE_MINUTE",
429
+ "TZ": "TIMEZONE_MINUTE",
430
+ "ISO_WEEK": "WEEKISO",
431
+ "ISOWK": "WEEKISO",
432
+ "ISOWW": "WEEKISO",
418
433
  }
419
434
 
420
435
  TIME_MAPPING = {
@@ -426,9 +441,9 @@ class TSQL(Dialect):
426
441
  "week": "%W",
427
442
  "ww": "%W",
428
443
  "wk": "%W",
429
- "isowk": "%IW",
430
- "isoww": "%IW",
431
- "iso_week": "%IW",
444
+ "isowk": "%V",
445
+ "isoww": "%V",
446
+ "iso_week": "%V",
432
447
  "hour": "%h",
433
448
  "hh": "%I",
434
449
  "minute": "%M",
@@ -574,7 +589,7 @@ class TSQL(Dialect):
574
589
  QUERY_MODIFIER_PARSERS = {
575
590
  **parser.Parser.QUERY_MODIFIER_PARSERS,
576
591
  TokenType.OPTION: lambda self: ("options", self._parse_options()),
577
- TokenType.FOR: lambda self: ("for", self._parse_for()),
592
+ TokenType.FOR: lambda self: ("for_", self._parse_for()),
578
593
  }
579
594
 
580
595
  # T-SQL does not allow BEGIN to be used as an identifier
@@ -599,8 +614,10 @@ class TSQL(Dialect):
599
614
  ),
600
615
  "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
601
616
  "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
617
+ "DATEDIFF_BIG": _build_date_delta(
618
+ exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL, big_int=True
619
+ ),
602
620
  "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True),
603
- "DATEPART": _build_formatted_time(exp.TimeToStr),
604
621
  "DATETIMEFROMPARTS": _build_datetimefromparts,
605
622
  "EOMONTH": _build_eomonth,
606
623
  "FORMAT": _build_format,
@@ -666,6 +683,7 @@ class TSQL(Dialect):
666
683
  order=self._parse_order(),
667
684
  null_handling=self._parse_on_handling("NULL", "NULL", "ABSENT"),
668
685
  ),
686
+ "DATEPART": lambda self: self._parse_datepart(),
669
687
  }
670
688
 
671
689
  # The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL
@@ -684,6 +702,13 @@ class TSQL(Dialect):
684
702
  "ts": exp.Timestamp,
685
703
  }
686
704
 
705
+ def _parse_datepart(self) -> exp.Extract:
706
+ this = self._parse_var()
707
+ expression = self._match(TokenType.COMMA) and self._parse_bitwise()
708
+ name = map_date_part(this, self.dialect)
709
+
710
+ return self.expression(exp.Extract, this=name, expression=expression)
711
+
687
712
  def _parse_alter_table_set(self) -> exp.AlterSet:
688
713
  return self._parse_wrapped(super()._parse_alter_table_set)
689
714
 
@@ -821,7 +846,6 @@ class TSQL(Dialect):
821
846
  args = [this, *self._parse_csv(self._parse_assignment)]
822
847
  convert = exp.Convert.from_arg_list(args)
823
848
  convert.set("safe", safe)
824
- convert.set("strict", strict)
825
849
  return convert
826
850
 
827
851
  def _parse_column_def(
@@ -878,7 +902,7 @@ class TSQL(Dialect):
878
902
  this = super()._parse_id_var(any_token=any_token, tokens=tokens)
879
903
  if this:
880
904
  if is_global:
881
- this.set("global", True)
905
+ this.set("global_", True)
882
906
  elif is_temporary:
883
907
  this.set("temporary", True)
884
908
 
@@ -1033,15 +1057,14 @@ class TSQL(Dialect):
1033
1057
  exp.AnyValue: any_value_to_max_sql,
1034
1058
  exp.ArrayToString: rename_func("STRING_AGG"),
1035
1059
  exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
1060
+ exp.Ceil: rename_func("CEILING"),
1036
1061
  exp.Chr: rename_func("CHAR"),
1037
1062
  exp.DateAdd: date_delta_sql("DATEADD"),
1038
- exp.DateDiff: date_delta_sql("DATEDIFF"),
1039
1063
  exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
1040
1064
  exp.CurrentDate: rename_func("GETDATE"),
1041
1065
  exp.CurrentTimestamp: rename_func("GETDATE"),
1042
1066
  exp.CurrentTimestampLTZ: rename_func("SYSDATETIMEOFFSET"),
1043
1067
  exp.DateStrToDate: datestrtodate_sql,
1044
- exp.Extract: rename_func("DATEPART"),
1045
1068
  exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
1046
1069
  exp.GroupConcat: _string_agg_sql,
1047
1070
  exp.If: rename_func("IIF"),
@@ -1069,6 +1092,9 @@ class TSQL(Dialect):
1069
1092
  ),
1070
1093
  exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
1071
1094
  exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
1095
+ exp.SHA1Digest: lambda self, e: self.func(
1096
+ "HASHBYTES", exp.Literal.string("SHA1"), e.this
1097
+ ),
1072
1098
  exp.SHA2: lambda self, e: self.func(
1073
1099
  "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
1074
1100
  ),
@@ -1163,6 +1189,12 @@ class TSQL(Dialect):
1163
1189
  "PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py())
1164
1190
  )
1165
1191
 
1192
+ def extract_sql(self, expression: exp.Extract) -> str:
1193
+ part = expression.this
1194
+ name = DATE_PART_UNMAPPING.get(part.name.upper()) or part
1195
+
1196
+ return self.func("DATEPART", name, expression.expression)
1197
+
1166
1198
  def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
1167
1199
  nano = expression.args.get("nano")
1168
1200
  if nano is not None:
@@ -1238,12 +1270,12 @@ class TSQL(Dialect):
1238
1270
 
1239
1271
  if kind == "VIEW":
1240
1272
  expression.this.set("catalog", None)
1241
- with_ = expression.args.get("with")
1273
+ with_ = expression.args.get("with_")
1242
1274
  if ctas_expression and with_:
1243
1275
  # We've already preprocessed the Create expression to bubble up any nested CTEs,
1244
1276
  # but CREATE VIEW actually requires the WITH clause to come after it so we need
1245
1277
  # to amend the AST by moving the CTEs to the CREATE VIEW statement's query.
1246
- ctas_expression.set("with", with_.pop())
1278
+ ctas_expression.set("with_", with_.pop())
1247
1279
 
1248
1280
  table = expression.find(exp.Table)
1249
1281
 
@@ -1301,6 +1333,10 @@ class TSQL(Dialect):
1301
1333
  func_name = "COUNT_BIG" if expression.args.get("big_int") else "COUNT"
1302
1334
  return rename_func(func_name)(self, expression)
1303
1335
 
1336
+ def datediff_sql(self, expression: exp.DateDiff) -> str:
1337
+ func_name = "DATEDIFF_BIG" if expression.args.get("big_int") else "DATEDIFF"
1338
+ return date_delta_sql(func_name)(self, expression)
1339
+
1304
1340
  def offset_sql(self, expression: exp.Offset) -> str:
1305
1341
  return f"{super().offset_sql(expression)} ROWS"
1306
1342
 
@@ -1355,7 +1391,7 @@ class TSQL(Dialect):
1355
1391
  def identifier_sql(self, expression: exp.Identifier) -> str:
1356
1392
  identifier = super().identifier_sql(expression)
1357
1393
 
1358
- if expression.args.get("global"):
1394
+ if expression.args.get("global_"):
1359
1395
  identifier = f"##{identifier}"
1360
1396
  elif expression.args.get("temporary"):
1361
1397
  identifier = f"#{identifier}"
sqlglot/diff.py CHANGED
@@ -393,8 +393,10 @@ def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Express
393
393
 
394
394
  def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]:
395
395
  for arg, value in expression.args.items():
396
- if isinstance(value, exp.Expression) or (
397
- isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression)
396
+ if (
397
+ value is None
398
+ or isinstance(value, exp.Expression)
399
+ or (isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression))
398
400
  ):
399
401
  continue
400
402
 
sqlglot/errors.py CHANGED
@@ -6,6 +6,12 @@ from enum import auto
6
6
  from sqlglot.helper import AutoName
7
7
 
8
8
 
9
+ # ANSI escape codes for error formatting
10
+ ANSI_UNDERLINE = "\033[4m"
11
+ ANSI_RESET = "\033[0m"
12
+ ERROR_MESSAGE_CONTEXT_DEFAULT = 100
13
+
14
+
9
15
  class ErrorLevel(AutoName):
10
16
  IGNORE = auto()
11
17
  """Ignore all errors."""
@@ -81,6 +87,69 @@ class ExecuteError(SqlglotError):
81
87
  pass
82
88
 
83
89
 
90
+ def highlight_sql(
91
+ sql: str,
92
+ positions: t.List[t.Tuple[int, int]],
93
+ context_length: int = ERROR_MESSAGE_CONTEXT_DEFAULT,
94
+ ) -> t.Tuple[str, str, str, str]:
95
+ """
96
+ Highlight a SQL string using ANSI codes at the given positions.
97
+
98
+ Args:
99
+ sql: The complete SQL string.
100
+ positions: List of (start, end) tuples where both start and end are inclusive 0-based
101
+ indexes. For example, to highlight "foo" in "SELECT foo", use (7, 9).
102
+ The positions will be sorted and de-duplicated if they overlap.
103
+ context_length: Number of characters to show before the first highlight and after
104
+ the last highlight.
105
+
106
+ Returns:
107
+ A tuple of (formatted_sql, start_context, highlight, end_context) where:
108
+ - formatted_sql: The SQL with ANSI underline codes applied to highlighted sections
109
+ - start_context: Plain text before the first highlight
110
+ - highlight: Plain text from the first highlight start to the last highlight end,
111
+ including any non-highlighted text in between (no ANSI)
112
+ - end_context: Plain text after the last highlight
113
+
114
+ Note:
115
+ If positions is empty, raises a ValueError.
116
+ """
117
+ if not positions:
118
+ raise ValueError("positions must contain at least one (start, end) tuple")
119
+
120
+ start_context = ""
121
+ end_context = ""
122
+ first_highlight_start = 0
123
+ formatted_parts = []
124
+ previous_part_end = 0
125
+ sorted_positions = sorted(positions, key=lambda pos: pos[0])
126
+
127
+ if sorted_positions[0][0] > 0:
128
+ first_highlight_start = sorted_positions[0][0]
129
+ start_context = sql[max(0, first_highlight_start - context_length) : first_highlight_start]
130
+ formatted_parts.append(start_context)
131
+ previous_part_end = first_highlight_start
132
+
133
+ for start, end in sorted_positions:
134
+ highlight_start = max(start, previous_part_end)
135
+ highlight_end = end + 1
136
+ if highlight_start >= highlight_end:
137
+ continue # Skip invalid or overlapping highlights
138
+ if highlight_start > previous_part_end:
139
+ formatted_parts.append(sql[previous_part_end:highlight_start])
140
+ formatted_parts.append(f"{ANSI_UNDERLINE}{sql[highlight_start:highlight_end]}{ANSI_RESET}")
141
+ previous_part_end = highlight_end
142
+
143
+ if previous_part_end < len(sql):
144
+ end_context = sql[previous_part_end : previous_part_end + context_length]
145
+ formatted_parts.append(end_context)
146
+
147
+ formatted_sql = "".join(formatted_parts)
148
+ highlight = sql[first_highlight_start:previous_part_end]
149
+
150
+ return formatted_sql, start_context, highlight, end_context
151
+
152
+
84
153
  def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
85
154
  msg = [str(e) for e in errors[:maximum]]
86
155
  remaining = len(errors) - maximum