sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/dialects/spark.py CHANGED
@@ -4,8 +4,8 @@ import typing as t
 
 from sqlglot import exp
 from sqlglot.dialects.dialect import (
-    Version,
     rename_func,
+    build_like,
     unit_to_var,
     timestampdiff_sql,
     build_date_delta,
@@ -99,7 +99,7 @@ def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.Timestam
 
 
 def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
-    if self.dialect.version < Version("4.0.0"):
+    if self.dialect.version < (4,):
         expr = exp.ArrayToString(
             this=exp.ArrayAgg(this=expression.this),
             expression=expression.args.get("separator") or exp.Literal.string(""),
@@ -111,6 +111,7 @@ def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
 
 class Spark(Spark2):
     SUPPORTS_ORDER_BY_ALL = True
+    SUPPORTS_NULL_TYPE = True
 
     class Tokenizer(Spark2.Tokenizer):
         STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
@@ -128,7 +129,7 @@ class Spark(Spark2):
             "BIT_AND": exp.BitwiseAndAgg.from_arg_list,
             "BIT_OR": exp.BitwiseOrAgg.from_arg_list,
             "BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
-            "BIT_COUNT": exp.BitwiseCountAgg.from_arg_list,
+            "BIT_COUNT": exp.BitwiseCount.from_arg_list,
             "DATE_ADD": _build_dateadd,
             "DATEADD": _build_dateadd,
             "TIMESTAMPADD": _build_dateadd,
@@ -147,6 +148,8 @@ class Spark(Spark2):
                 offset=1,
                 safe=True,
             ),
+            "LIKE": build_like(exp.Like),
+            "ILIKE": build_like(exp.ILike),
         }
 
         PLACEHOLDER_PARSERS = {
@@ -196,7 +199,7 @@ class Spark(Spark2):
             exp.BitwiseAndAgg: rename_func("BIT_AND"),
             exp.BitwiseOrAgg: rename_func("BIT_OR"),
             exp.BitwiseXorAgg: rename_func("BIT_XOR"),
-            exp.BitwiseCountAgg: rename_func("BIT_COUNT"),
+            exp.BitwiseCount: rename_func("BIT_COUNT"),
             exp.Create: preprocess(
                 [
                     remove_unique_constraints,
@@ -258,3 +261,11 @@ class Spark(Spark2):
                 return super().placeholder_sql(expression)
 
             return f"{{{expression.name}}}"
+
+        def readparquet_sql(self, expression: exp.ReadParquet) -> str:
+            if len(expression.expressions) != 1:
+                self.unsupported("READ_PARQUET with multiple arguments is not supported")
+                return ""
+
+            parquet_file = expression.expressions[0]
+            return f"parquet.`{parquet_file.name}`"
sqlglot/dialects/spark2.py CHANGED
@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
     unit_to_str,
 )
 from sqlglot.dialects.hive import Hive
-from sqlglot.helper import seq_get, ensure_list
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 from sqlglot.transforms import (
     preprocess,
@@ -21,11 +21,7 @@ from sqlglot.transforms import (
     ctas_with_tmp_tables_to_create_tmp_view,
     move_schema_columns_to_partitioned_by,
 )
-
-if t.TYPE_CHECKING:
-    from sqlglot._typing import E
-
-    from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.typing.spark2 import EXPRESSION_METADATA
 
 
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -118,51 +114,15 @@ def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
-def _annotate_by_similar_args(
-    self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type
-) -> E:
-    """
-    Infers the type of the expression according to the following rules:
-    - If all args are of the same type OR any arg is of target_type, the expr is inferred as such
-    - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN
-    """
-    self._annotate_args(expression)
-
-    expressions: t.List[exp.Expression] = []
-    for arg in args:
-        arg_expr = expression.args.get(arg)
-        expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
-
-    last_datatype = None
-
-    has_unknown = False
-    for expr in expressions:
-        if expr.is_type(exp.DataType.Type.UNKNOWN):
-            has_unknown = True
-        elif expr.is_type(target_type):
-            has_unknown = False
-            last_datatype = target_type
-            break
-        else:
-            last_datatype = expr.type
-
-    self._set_type(expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype)
-    return expression
-
-
 class Spark2(Hive):
     ALTER_TABLE_SUPPORTS_CASCADE = False
 
-    ANNOTATORS = {
-        **Hive.ANNOTATORS,
-        exp.Substring: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Concat: lambda self, e: _annotate_by_similar_args(
-            self, e, "expressions", target_type=exp.DataType.Type.TEXT
-        ),
-        exp.Pad: lambda self, e: _annotate_by_similar_args(
-            self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT
-        ),
-    }
+    EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
+
+    # https://spark.apache.org/docs/latest/api/sql/index.html#initcap
+    # https://docs.databricks.com/aws/en/sql/language-manual/functions/initcap
+    # https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L859-L905
+    INITCAP_DEFAULT_DELIMITER_CHARS = " "
 
     class Tokenizer(Hive.Tokenizer):
         HEX_STRINGS = [("X'", "'"), ("x'", "'")]
@@ -322,6 +282,9 @@ class Spark2(Hive):
                     transforms.any_to_exists,
                 ]
             ),
+            exp.SHA2Digest: lambda self, e: self.func(
+                "SHA2", e.this, e.args.get("length") or exp.Literal.number(256)
+            ),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
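Spark 2 also gains a generator mapping for exp.SHA2Digest, filling in the default digest length of 256 when none is set. A small sketch, assuming exp.SHA2Digest carries this/length args as the transform above implies:

    from sqlglot import exp
    from sqlglot.dialects.spark2 import Spark2

    # Hypothetical hand-built node; a missing length should fall back to 256.
    node = exp.SHA2Digest(this=exp.column("payload"))
    print(Spark2().generate(node))  # expected along the lines of: SHA2(payload, 256)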
sqlglot/dialects/sqlite.py CHANGED
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
     strposition_sql,
 )
 from sqlglot.generator import unsupported_args
+from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
 
 
@@ -89,6 +90,7 @@ class SQLite(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     TYPED_DIVISION = True
     SAFE_DIVISION = True
+    SAFE_TO_ELIMINATE_DOUBLE_NEGATION = False
 
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -100,6 +102,8 @@ class SQLite(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "ATTACH": TokenType.ATTACH,
             "DETACH": TokenType.DETACH,
+            "INDEXED BY": TokenType.INDEXED_BY,
+            "MATCH": TokenType.MATCH,
         }
 
         KEYWORDS.pop("/*+")
@@ -126,6 +130,12 @@ class SQLite(Dialect):
             TokenType.DETACH: lambda self: self._parse_attach_detach(is_attach=False),
         }
 
+        RANGE_PARSERS = {
+            **parser.Parser.RANGE_PARSERS,
+            # https://www.sqlite.org/lang_expr.html
+            TokenType.MATCH: binary_range_parser(exp.Match),
+        }
+
         def _parse_unique(self) -> exp.UniqueColumnConstraint:
             # Do not consume more tokens if UNIQUE is used as a standalone constraint, e.g:
             # CREATE TABLE foo (bar TEXT UNIQUE REFERENCES baz ...)
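SQLite now tokenizes INDEXED BY and MATCH, and parses MATCH as a binary range operator into the exp.Match node referenced above. A parse sketch, assuming the node round-trips back to MATCH syntax:

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT * FROM docs WHERE body MATCH 'sqlglot'", read="sqlite")
    match_node = ast.find(exp.Match)
    print(match_node.sql(dialect="sqlite") if match_node else "no MATCH node found")
    # Expected shape: body MATCH 'sqlglot'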
sqlglot/dialects/starrocks.py CHANGED
@@ -127,6 +127,9 @@ class StarRocks(MySQL):
         PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"
         WITH_PROPERTIES_PREFIX = "PROPERTIES"
 
+        # StarRocks doesn't support "IS TRUE/FALSE" syntax.
+        IS_BOOL_ALLOWED = False
+
         CAST_MAPPING = {}
 
         TYPE_MAPPING = {
sqlglot/dialects/teradata.py CHANGED
@@ -213,13 +213,10 @@ class Teradata(Dialect):
         def _parse_update(self) -> exp.Update:
             return self.expression(
                 exp.Update,
-                **{  # type: ignore
-                    "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
-                    "from": self._parse_from(joins=True),
-                    "expressions": self._match(TokenType.SET)
-                    and self._parse_csv(self._parse_equality),
-                    "where": self._parse_where(),
-                },
+                this=self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
+                from_=self._parse_from(joins=True),
+                expressions=self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
+                where=self._parse_where(),
             )
 
         def _parse_rangen(self):
@@ -387,7 +384,7 @@ class Teradata(Dialect):
         # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
         def update_sql(self, expression: exp.Update) -> str:
             this = self.sql(expression, "this")
-            from_sql = self.sql(expression, "from")
+            from_sql = self.sql(expression, "from_")
             set_sql = self.expressions(expression, flat=True)
             where_sql = self.sql(expression, "where")
             sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
sqlglot/dialects/trino.py CHANGED
@@ -16,6 +16,12 @@ class Trino(Presto):
     SUPPORTS_USER_DEFINED_TYPES = False
     LOG_BASE_FIRST = True
 
+    class Tokenizer(Presto.Tokenizer):
+        KEYWORDS = {
+            **Presto.Tokenizer.KEYWORDS,
+            "REFRESH": TokenType.REFRESH,
+        }
+
     class Parser(Presto.Parser):
         FUNCTION_PARSERS = {
             **Presto.Parser.FUNCTION_PARSERS,
sqlglot/dialects/tsql.py CHANGED
@@ -20,11 +20,13 @@ from sqlglot.dialects.dialect import (
     strposition_sql,
     timestrtotime_sql,
     trim_sql,
+    map_date_part,
 )
 from sqlglot.helper import seq_get
 from sqlglot.parser import build_coalesce
 from sqlglot.time import format_time
 from sqlglot.tokens import TokenType
+from sqlglot.typing.tsql import EXPRESSION_METADATA
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
@@ -56,6 +58,11 @@ DATE_DELTA_INTERVAL = {
     "d": "day",
 }
 
+DATE_PART_UNMAPPING = {
+    "WEEKISO": "ISO_WEEK",
+    "DAYOFWEEK": "WEEKDAY",
+    "TIMEZONE_MINUTE": "TZOFFSET",
+}
 
 
 DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
@@ -200,20 +207,12 @@ def _build_hashbytes(args: t.List) -> exp.Expression:
     return exp.func("HASHBYTES", *args)
 
 
-DATEPART_ONLY_FORMATS = {"DW", "WK", "HOUR", "QUARTER"}
-
-
 def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
     fmt = expression.args["format"]
 
     if not isinstance(expression, exp.NumberToStr):
         if fmt.is_string:
             mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
-
-            name = (mapped_fmt or "").upper()
-            if name in DATEPART_ONLY_FORMATS:
-                return self.func("DATEPART", name, expression.this)
-
             fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
         else:
             fmt_sql = self.format_time(expression) or self.sql(fmt)
@@ -243,7 +242,7 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
 
 
 def _build_date_delta(
-    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None, big_int: bool = False
 ) -> t.Callable[[t.List], E]:
     def _builder(args: t.List) -> E:
         unit = seq_get(args, 0)
@@ -259,12 +258,15 @@ def _build_date_delta(
         else:
             # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
             # This is not a problem when generating T-SQL code, it is when transpiling to other dialects.
-            return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
+            return exp_class(
+                this=seq_get(args, 2), expression=start_date, unit=unit, big_int=big_int
+            )
 
         return exp_class(
             this=exp.TimeStrToTime(this=seq_get(args, 2)),
             expression=exp.TimeStrToTime(this=start_date),
             unit=unit,
+            big_int=big_int,
         )
 
     return _builder
@@ -412,9 +414,22 @@ class TSQL(Dialect):
 
     TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
 
-    ANNOTATORS = {
-        **Dialect.ANNOTATORS,
-        exp.Radians: lambda self, e: self._annotate_by_args(e, "this"),
+    EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
+
+    DATE_PART_MAPPING = {
+        **Dialect.DATE_PART_MAPPING,
+        "QQ": "QUARTER",
+        "M": "MONTH",
+        "Y": "DAYOFYEAR",
+        "WW": "WEEK",
+        "N": "MINUTE",
+        "SS": "SECOND",
+        "MCS": "MICROSECOND",
+        "TZOFFSET": "TIMEZONE_MINUTE",
+        "TZ": "TIMEZONE_MINUTE",
+        "ISO_WEEK": "WEEKISO",
+        "ISOWK": "WEEKISO",
+        "ISOWW": "WEEKISO",
     }
 
     TIME_MAPPING = {
@@ -426,6 +441,9 @@ class TSQL(Dialect):
         "week": "%W",
         "ww": "%W",
         "wk": "%W",
+        "isowk": "%V",
+        "isoww": "%V",
+        "iso_week": "%V",
         "hour": "%h",
         "hh": "%I",
         "minute": "%M",
@@ -571,7 +589,7 @@ class TSQL(Dialect):
         QUERY_MODIFIER_PARSERS = {
             **parser.Parser.QUERY_MODIFIER_PARSERS,
             TokenType.OPTION: lambda self: ("options", self._parse_options()),
-            TokenType.FOR: lambda self: ("for", self._parse_for()),
+            TokenType.FOR: lambda self: ("for_", self._parse_for()),
         }
 
         # T-SQL does not allow BEGIN to be used as an identifier
@@ -596,8 +614,10 @@ class TSQL(Dialect):
             ),
             "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
             "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+            "DATEDIFF_BIG": _build_date_delta(
+                exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL, big_int=True
+            ),
             "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True),
-            "DATEPART": _build_formatted_time(exp.TimeToStr),
             "DATETIMEFROMPARTS": _build_datetimefromparts,
             "EOMONTH": _build_eomonth,
             "FORMAT": _build_format,
@@ -663,6 +683,7 @@ class TSQL(Dialect):
                 order=self._parse_order(),
                 null_handling=self._parse_on_handling("NULL", "NULL", "ABSENT"),
             ),
+            "DATEPART": lambda self: self._parse_datepart(),
         }
 
         # The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL
@@ -681,6 +702,13 @@ class TSQL(Dialect):
             "ts": exp.Timestamp,
         }
 
+        def _parse_datepart(self) -> exp.Extract:
+            this = self._parse_var()
+            expression = self._match(TokenType.COMMA) and self._parse_bitwise()
+            name = map_date_part(this, self.dialect)
+
+            return self.expression(exp.Extract, this=name, expression=expression)
+
         def _parse_alter_table_set(self) -> exp.AlterSet:
             return self._parse_wrapped(super()._parse_alter_table_set)
 
@@ -818,7 +846,6 @@ class TSQL(Dialect):
             args = [this, *self._parse_csv(self._parse_assignment)]
             convert = exp.Convert.from_arg_list(args)
             convert.set("safe", safe)
-            convert.set("strict", strict)
             return convert
 
         def _parse_column_def(
@@ -875,7 +902,7 @@ class TSQL(Dialect):
             this = super()._parse_id_var(any_token=any_token, tokens=tokens)
             if this:
                 if is_global:
-                    this.set("global", True)
+                    this.set("global_", True)
                 elif is_temporary:
                     this.set("temporary", True)
 
@@ -1030,15 +1057,14 @@ class TSQL(Dialect):
             exp.AnyValue: any_value_to_max_sql,
             exp.ArrayToString: rename_func("STRING_AGG"),
             exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
+            exp.Ceil: rename_func("CEILING"),
             exp.Chr: rename_func("CHAR"),
             exp.DateAdd: date_delta_sql("DATEADD"),
-            exp.DateDiff: date_delta_sql("DATEDIFF"),
             exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
             exp.CurrentDate: rename_func("GETDATE"),
             exp.CurrentTimestamp: rename_func("GETDATE"),
             exp.CurrentTimestampLTZ: rename_func("SYSDATETIMEOFFSET"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.Extract: rename_func("DATEPART"),
             exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
             exp.GroupConcat: _string_agg_sql,
             exp.If: rename_func("IIF"),
@@ -1066,6 +1092,9 @@ class TSQL(Dialect):
             ),
             exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
             exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
+            exp.SHA1Digest: lambda self, e: self.func(
+                "HASHBYTES", exp.Literal.string("SHA1"), e.this
+            ),
             exp.SHA2: lambda self, e: self.func(
                 "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
             ),
@@ -1160,6 +1189,12 @@ class TSQL(Dialect):
                 "PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py())
             )
 
+        def extract_sql(self, expression: exp.Extract) -> str:
+            part = expression.this
+            name = DATE_PART_UNMAPPING.get(part.name.upper()) or part
+
+            return self.func("DATEPART", name, expression.expression)
+
         def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
             nano = expression.args.get("nano")
             if nano is not None:
@@ -1235,12 +1270,12 @@ class TSQL(Dialect):
 
             if kind == "VIEW":
                 expression.this.set("catalog", None)
-                with_ = expression.args.get("with")
+                with_ = expression.args.get("with_")
                 if ctas_expression and with_:
                     # We've already preprocessed the Create expression to bubble up any nested CTEs,
                     # but CREATE VIEW actually requires the WITH clause to come after it so we need
                     # to amend the AST by moving the CTEs to the CREATE VIEW statement's query.
-                    ctas_expression.set("with", with_.pop())
+                    ctas_expression.set("with_", with_.pop())
 
             table = expression.find(exp.Table)
 
@@ -1298,6 +1333,10 @@ class TSQL(Dialect):
             func_name = "COUNT_BIG" if expression.args.get("big_int") else "COUNT"
             return rename_func(func_name)(self, expression)
 
+        def datediff_sql(self, expression: exp.DateDiff) -> str:
+            func_name = "DATEDIFF_BIG" if expression.args.get("big_int") else "DATEDIFF"
+            return date_delta_sql(func_name)(self, expression)
+
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"
 
@@ -1352,7 +1391,7 @@ class TSQL(Dialect):
         def identifier_sql(self, expression: exp.Identifier) -> str:
             identifier = super().identifier_sql(expression)
 
-            if expression.args.get("global"):
+            if expression.args.get("global_"):
                 identifier = f"##{identifier}"
             elif expression.args.get("temporary"):
                 identifier = f"#{identifier}"
sqlglot/diff.py CHANGED
@@ -393,8 +393,10 @@ def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Express
 
 def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]:
     for arg, value in expression.args.items():
-        if isinstance(value, exp.Expression) or (
-            isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression)
+        if (
+            value is None
+            or isinstance(value, exp.Expression)
+            or (isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression))
         ):
             continue
 
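The change above makes _get_non_expression_leaves skip None-valued args when the semantic differ matches nodes, so unset args no longer count as mismatching leaves. For context, the entry point is the diff function in sqlglot.diff, which compares two parsed expressions and returns a list of edit operations:

    from sqlglot import parse_one
    from sqlglot.diff import diff

    # Compare two small queries; each edit is an operation such as Insert, Remove, Update or Keep.
    edits = diff(parse_one("SELECT a, b FROM t"), parse_one("SELECT a, c FROM t"))
    for edit in edits:
        print(type(edit).__name__)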
sqlglot/errors.py CHANGED
@@ -6,6 +6,12 @@ from enum import auto
 from sqlglot.helper import AutoName
 
 
+# ANSI escape codes for error formatting
+ANSI_UNDERLINE = "\033[4m"
+ANSI_RESET = "\033[0m"
+ERROR_MESSAGE_CONTEXT_DEFAULT = 100
+
+
 class ErrorLevel(AutoName):
     IGNORE = auto()
     """Ignore all errors."""
@@ -81,6 +87,69 @@ class ExecuteError(SqlglotError):
     pass
 
 
+def highlight_sql(
+    sql: str,
+    positions: t.List[t.Tuple[int, int]],
+    context_length: int = ERROR_MESSAGE_CONTEXT_DEFAULT,
+) -> t.Tuple[str, str, str, str]:
+    """
+    Highlight a SQL string using ANSI codes at the given positions.
+
+    Args:
+        sql: The complete SQL string.
+        positions: List of (start, end) tuples where both start and end are inclusive 0-based
+            indexes. For example, to highlight "foo" in "SELECT foo", use (7, 9).
+            The positions will be sorted and de-duplicated if they overlap.
+        context_length: Number of characters to show before the first highlight and after
+            the last highlight.
+
+    Returns:
+        A tuple of (formatted_sql, start_context, highlight, end_context) where:
+        - formatted_sql: The SQL with ANSI underline codes applied to highlighted sections
+        - start_context: Plain text before the first highlight
+        - highlight: Plain text from the first highlight start to the last highlight end,
+          including any non-highlighted text in between (no ANSI)
+        - end_context: Plain text after the last highlight
+
+    Note:
+        If positions is empty, raises a ValueError.
+    """
+    if not positions:
+        raise ValueError("positions must contain at least one (start, end) tuple")
+
+    start_context = ""
+    end_context = ""
+    first_highlight_start = 0
+    formatted_parts = []
+    previous_part_end = 0
+    sorted_positions = sorted(positions, key=lambda pos: pos[0])
+
+    if sorted_positions[0][0] > 0:
+        first_highlight_start = sorted_positions[0][0]
+        start_context = sql[max(0, first_highlight_start - context_length) : first_highlight_start]
+        formatted_parts.append(start_context)
+        previous_part_end = first_highlight_start
+
+    for start, end in sorted_positions:
+        highlight_start = max(start, previous_part_end)
+        highlight_end = end + 1
+        if highlight_start >= highlight_end:
+            continue  # Skip invalid or overlapping highlights
+        if highlight_start > previous_part_end:
+            formatted_parts.append(sql[previous_part_end:highlight_start])
+        formatted_parts.append(f"{ANSI_UNDERLINE}{sql[highlight_start:highlight_end]}{ANSI_RESET}")
+        previous_part_end = highlight_end
+
+    if previous_part_end < len(sql):
+        end_context = sql[previous_part_end : previous_part_end + context_length]
+        formatted_parts.append(end_context)
+
+    formatted_sql = "".join(formatted_parts)
+    highlight = sql[first_highlight_start:previous_part_end]
+
+    return formatted_sql, start_context, highlight, end_context
+
+
 def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str:
     msg = [str(e) for e in errors[:maximum]]
     remaining = len(errors) - maximum
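The new module-level highlight_sql helper underlines the given positions with ANSI codes and also returns the plain-text context pieces. A usage sketch taken from the docstring's own example:

    from sqlglot.errors import highlight_sql

    formatted, start_context, highlight, end_context = highlight_sql("SELECT foo FROM bar", [(7, 9)])
    print(repr(start_context))  # 'SELECT '
    print(repr(highlight))      # 'foo'
    print(repr(end_context))    # ' FROM bar'
    # formatted is the same string with "foo" wrapped in ANSI underline escape codes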
sqlglot/executor/__init__.py CHANGED
@@ -31,7 +31,6 @@ if t.TYPE_CHECKING:
 def execute(
     sql: str | Expression,
     schema: t.Optional[t.Dict | Schema] = None,
-    read: DialectType = None,
     dialect: DialectType = None,
     tables: t.Optional[t.Dict] = None,
 ) -> Table:
@@ -45,15 +44,13 @@ def execute(
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
-        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
-        dialect: the SQL dialect (alias for read).
+        dialect: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
         tables: additional tables to register.
 
     Returns:
         Simple columnar data structure.
     """
-    read = read or dialect
-    tables_ = ensure_tables(tables, dialect=read)
+    tables_ = ensure_tables(tables, dialect=dialect)
 
     if not schema:
         schema = {}
@@ -66,19 +63,17 @@ def execute(
             for column in table.columns:
                 value = table[0][column]
                 column_type = (
-                    annotate_types(exp.convert(value), dialect=read).type or type(value).__name__
+                    annotate_types(exp.convert(value), dialect=dialect).type or type(value).__name__
                 )
                 nested_set(schema, [*keys, column], column_type)
 
-    schema = ensure_schema(schema, dialect=read)
+    schema = ensure_schema(schema, dialect=dialect)
 
     if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
         raise ExecuteError("Tables must support the same table args as schema")
 
     now = time.time()
-    expression = optimize(
-        sql, schema, leave_tables_isolated=True, infer_csv_schemas=True, dialect=read
-    )
+    expression = optimize(sql, schema, leave_tables_isolated=True, dialect=dialect)
 
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))