sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import typing as t
4
4
 
5
- from sqlglot import exp, generator, parser, tokens
5
+ from sqlglot import exp, generator, parser, tokens, transforms
6
6
  from sqlglot.dialects.dialect import (
7
7
  Dialect,
8
8
  NormalizationStrategy,
@@ -12,13 +12,15 @@ from sqlglot.dialects.dialect import (
12
12
  rename_func,
13
13
  strposition_sql,
14
14
  timestrtotime_sql,
15
- unit_to_str,
16
15
  timestamptrunc_sql,
17
16
  build_date_delta,
17
+ no_last_day_sql,
18
+ DATE_ADD_OR_SUB,
18
19
  )
19
20
  from sqlglot.generator import unsupported_args
20
21
  from sqlglot.helper import seq_get
21
22
  from sqlglot.tokens import TokenType
23
+ from sqlglot.optimizer.scope import build_scope
22
24
 
23
25
  if t.TYPE_CHECKING:
24
26
  from sqlglot.dialects.dialect import DialectType
@@ -71,6 +73,200 @@ def _build_nullifzero(args: t.List) -> exp.If:
71
73
  return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
72
74
 
73
75
 
76
+ # https://docs.exasol.com/db/latest/sql/select.htm#:~:text=If%20you%20have,local.x%3E10
77
+ def _add_local_prefix_for_aliases(expression: exp.Expression) -> exp.Expression:
78
+ if isinstance(expression, exp.Select):
79
+ aliases: dict[str, bool] = {
80
+ alias.name: bool(alias.args.get("quoted"))
81
+ for sel in expression.selects
82
+ if isinstance(sel, exp.Alias) and (alias := sel.args.get("alias"))
83
+ }
84
+
85
+ table = expression.find(exp.Table)
86
+ table_ident = table.this if table else None
87
+
88
+ if (
89
+ table_ident
90
+ and table_ident.name.upper() == "LOCAL"
91
+ and not bool(table_ident.args.get("quoted"))
92
+ ):
93
+ table_ident.replace(exp.to_identifier(table_ident.name.upper(), quoted=True))
94
+
95
+ def prefix_local(node, visible_aliases: dict[str, bool]) -> exp.Expression:
96
+ if isinstance(node, exp.Column) and not node.table:
97
+ if node.name in visible_aliases:
98
+ return exp.Column(
99
+ this=exp.to_identifier(node.name, quoted=visible_aliases[node.name]),
100
+ table=exp.to_identifier("LOCAL", quoted=False),
101
+ )
102
+ return node
103
+
104
+ for key in ("where", "group", "having"):
105
+ if arg := expression.args.get(key):
106
+ expression.set(key, arg.transform(lambda node: prefix_local(node, aliases)))
107
+
108
+ seen_aliases: dict[str, bool] = {}
109
+ new_selects: list[exp.Expression] = []
110
+ for sel in expression.selects:
111
+ if isinstance(sel, exp.Alias):
112
+ inner = sel.this.transform(lambda node: prefix_local(node, seen_aliases))
113
+ sel.set("this", inner)
114
+
115
+ alias_node = sel.args.get("alias")
116
+
117
+ seen_aliases[sel.alias] = bool(alias_node and getattr(alias_node, "quoted", False))
118
+ new_selects.append(sel)
119
+ else:
120
+ new_selects.append(sel.transform(lambda node: prefix_local(node, seen_aliases)))
121
+ expression.set("expressions", new_selects)
122
+
123
+ return expression
124
+
125
+
126
+ def _trunc_sql(self: Exasol.Generator, kind: str, expression: exp.DateTrunc) -> str:
127
+ unit = expression.text("unit")
128
+ node = expression.this.this if isinstance(expression.this, exp.Cast) else expression.this
129
+ expr_sql = self.sql(node)
130
+ if isinstance(node, exp.Literal) and node.is_string:
131
+ expr_sql = (
132
+ f"{kind} '{node.this.replace('T', ' ')}'"
133
+ if kind == "TIMESTAMP"
134
+ else f"DATE '{node.this}'"
135
+ )
136
+ return f"DATE_TRUNC('{unit}', {expr_sql})"
137
+
138
+
139
+ def _date_trunc_sql(self: Exasol.Generator, expression: exp.DateTrunc) -> str:
140
+ return _trunc_sql(self, "DATE", expression)
141
+
142
+
143
+ def _timestamp_trunc_sql(self: Exasol.Generator, expression: exp.DateTrunc) -> str:
144
+ return _trunc_sql(self, "TIMESTAMP", expression)
145
+
146
+
147
+ def is_case_insensitive(node: exp.Expression) -> bool:
148
+ return isinstance(node, exp.Collate) and node.text("expression").upper() == "UTF8_LCASE"
149
+
150
+
151
+ def _substring_index_sql(self: Exasol.Generator, expression: exp.SubstringIndex) -> str:
152
+ this = expression.this
153
+ delimiter = expression.args["delimiter"]
154
+ count_node = expression.args["count"]
155
+ count_sql = self.sql(expression, "count")
156
+ num = count_node.to_py() if count_node.is_number else 0
157
+
158
+ haystack_sql = self.sql(this)
159
+ if num == 0:
160
+ return self.func("SUBSTR", haystack_sql, "1", "0")
161
+
162
+ from_right = num < 0
163
+ direction = "-1" if from_right else "1"
164
+ occur = self.func("ABS", count_sql) if from_right else count_sql
165
+
166
+ delimiter_sql = self.sql(delimiter)
167
+
168
+ position = self.func(
169
+ "INSTR",
170
+ self.func("LOWER", haystack_sql) if is_case_insensitive(this) else haystack_sql,
171
+ self.func("LOWER", delimiter_sql) if is_case_insensitive(delimiter) else delimiter_sql,
172
+ direction,
173
+ occur,
174
+ )
175
+ nullable_pos = self.func("NULLIF", position, "0")
176
+
177
+ if from_right:
178
+ start = self.func(
179
+ "NVL", f"{nullable_pos} + {self.func('LENGTH', delimiter_sql)}", direction
180
+ )
181
+ return self.func("SUBSTR", haystack_sql, start)
182
+
183
+ length = self.func("NVL", f"{nullable_pos} - 1", self.func("LENGTH", haystack_sql))
184
+ return self.func("SUBSTR", haystack_sql, direction, length)
185
+
186
+
187
+ # https://docs.exasol.com/db/latest/sql/select.htm#:~:text=The%20select_list%20defines%20the%20columns%20of%20the%20result%20table.%20If%20*%20is%20used%2C%20all%20columns%20are%20listed.%20You%20can%20use%20an%20expression%20like%20t.*%20to%20list%20all%20columns%20of%20the%20table%20t%2C%20the%20view%20t%2C%20or%20the%20object%20with%20the%20table%20alias%20t.
188
+ def _qualify_unscoped_star(expression: exp.Expression) -> exp.Expression:
189
+ """
190
+ Exasol doesn't support a bare * alongside other select items, so we rewrite it
191
+ Rewrite: SELECT *, <other> FROM <Table>
192
+ Into: SELECT T.*, <other> FROM <Table> AS T
193
+ """
194
+
195
+ if not isinstance(expression, exp.Select):
196
+ return expression
197
+
198
+ select_expressions = expression.expressions or []
199
+
200
+ def is_bare_star(expr: exp.Expression) -> bool:
201
+ return isinstance(expr, exp.Star) and expr.this is None
202
+
203
+ has_other_expression = False
204
+ bare_star_expr: exp.Expression | None = None
205
+ for expr in select_expressions:
206
+ has_bare_star = is_bare_star(expr)
207
+ if has_bare_star and bare_star_expr is None:
208
+ bare_star_expr = expr
209
+ elif not has_bare_star:
210
+ has_other_expression = True
211
+ if bare_star_expr and has_other_expression:
212
+ break
213
+
214
+ if not (bare_star_expr and has_other_expression):
215
+ return expression
216
+
217
+ scope = build_scope(expression)
218
+
219
+ if not scope or not scope.selected_sources:
220
+ return expression
221
+
222
+ table_identifiers: list[exp.Identifier] = []
223
+
224
+ for source_name, (source_expr, _) in scope.selected_sources.items():
225
+ ident = (
226
+ source_expr.this.copy()
227
+ if isinstance(source_expr, exp.Table) and isinstance(source_expr.this, exp.Identifier)
228
+ else exp.to_identifier(source_name)
229
+ )
230
+ table_identifiers.append(ident)
231
+
232
+ qualified_star_columns = [
233
+ exp.Column(this=bare_star_expr.copy(), table=ident) for ident in table_identifiers
234
+ ]
235
+
236
+ new_select_expressions: list[exp.Expression] = []
237
+
238
+ for select_expr in select_expressions:
239
+ new_select_expressions.extend(qualified_star_columns) if is_bare_star(
240
+ select_expr
241
+ ) else new_select_expressions.append(select_expr)
242
+
243
+ expression.set("expressions", new_select_expressions)
244
+ return expression
245
+
246
+
247
+ def _add_date_sql(self: Exasol.Generator, expression: DATE_ADD_OR_SUB) -> str:
248
+ interval = expression.expression if isinstance(expression.expression, exp.Interval) else None
249
+
250
+ unit = (
251
+ (interval.text("unit") or "DAY").upper()
252
+ if interval is not None
253
+ else (expression.text("unit") or "DAY").upper()
254
+ )
255
+
256
+ if unit not in DATE_UNITS:
257
+ self.unsupported(f"'{unit}' is not supported in Exasol.")
258
+ return self.function_fallback_sql(expression)
259
+
260
+ offset_expr: exp.Expression = expression.expression
261
+ if interval is not None:
262
+ offset_expr = interval.this
263
+
264
+ if isinstance(expression, exp.DateSub):
265
+ offset_expr = exp.Neg(this=offset_expr)
266
+
267
+ return self.func(f"ADD_{unit}S", expression.this, offset_expr)
268
+
269
+
74
270
  DATE_UNITS = {"DAY", "WEEK", "MONTH", "YEAR", "HOUR", "MINUTE", "SECOND"}
75
271
 
76
272
 
@@ -115,6 +311,7 @@ class Exasol(Dialect):
115
311
  }
116
312
 
117
313
  class Tokenizer(tokens.Tokenizer):
314
+ IDENTIFIERS = ['"', ("[", "]")]
118
315
  KEYWORDS = {
119
316
  **tokens.Tokenizer.KEYWORDS,
120
317
  "USER": TokenType.CURRENT_USER,
@@ -197,6 +394,24 @@ class Exasol(Dialect):
197
394
  **dict.fromkeys(("GROUP_CONCAT", "LISTAGG"), lambda self: self._parse_group_concat()),
198
395
  }
199
396
 
397
+ def _parse_column(self) -> t.Optional[exp.Expression]:
398
+ column = super()._parse_column()
399
+ if not isinstance(column, exp.Column):
400
+ return column
401
+ table_ident = column.args.get("table")
402
+ if (
403
+ isinstance(table_ident, exp.Identifier)
404
+ and table_ident.name.upper() == "LOCAL"
405
+ and not bool(table_ident.args.get("quoted"))
406
+ ):
407
+ column.set("table", None)
408
+ return column
409
+
410
+ ODBC_DATETIME_LITERALS = {
411
+ "d": exp.Date,
412
+ "ts": exp.Timestamp,
413
+ }
414
+
200
415
  class Generator(generator.Generator):
201
416
  # https://docs.exasol.com/db/latest/sql_references/data_types/datatypedetails.htm#StringDataType
202
417
  STRING_TYPE_MAPPING = {
@@ -250,10 +465,14 @@ class Exasol(Dialect):
250
465
  # https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/bit_xor.htm
251
466
  exp.BitwiseXor: rename_func("BIT_XOR"),
252
467
  exp.DateDiff: _date_diff_sql,
468
+ exp.DateAdd: _add_date_sql,
469
+ exp.TsOrDsAdd: _add_date_sql,
470
+ exp.DateSub: _add_date_sql,
253
471
  # https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/div.htm#DIV
254
472
  exp.IntDiv: rename_func("DIV"),
255
473
  exp.TsOrDsDiff: _date_diff_sql,
256
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
474
+ exp.DateTrunc: _date_trunc_sql,
475
+ exp.DayOfWeek: lambda self, e: f"CAST(TO_CHAR({self.sql(e, 'this')}, 'D') AS INTEGER)",
257
476
  exp.DatetimeTrunc: timestamptrunc_sql(),
258
477
  exp.GroupConcat: lambda self, e: groupconcat_sql(
259
478
  self, e, func_name="LISTAGG", within_group=True
@@ -282,7 +501,7 @@ class Exasol(Dialect):
282
501
  exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
283
502
  exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
284
503
  exp.TimeStrToTime: timestrtotime_sql,
285
- exp.TimestampTrunc: timestamptrunc_sql(),
504
+ exp.TimestampTrunc: _timestamp_trunc_sql,
286
505
  exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
287
506
  exp.CurrentUser: lambda *_: "CURRENT_USER",
288
507
  exp.AtTimeZone: lambda self, e: self.func(
@@ -307,7 +526,20 @@ class Exasol(Dialect):
307
526
  exp.MD5Digest: rename_func("HASHTYPE_MD5"),
308
527
  # https://docs.exasol.com/db/latest/sql/create_view.htm
309
528
  exp.CommentColumnConstraint: lambda self, e: f"COMMENT IS {self.sql(e, 'this')}",
529
+ exp.Select: transforms.preprocess(
530
+ [
531
+ _qualify_unscoped_star,
532
+ _add_local_prefix_for_aliases,
533
+ ]
534
+ ),
535
+ exp.SubstringIndex: _substring_index_sql,
310
536
  exp.WeekOfYear: rename_func("WEEK"),
537
+ # https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/to_date.htm
538
+ exp.Date: rename_func("TO_DATE"),
539
+ # https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/to_timestamp.htm
540
+ exp.Timestamp: rename_func("TO_TIMESTAMP"),
541
+ exp.Quarter: lambda self, e: f"CEIL(MONTH(TO_DATE({self.sql(e, 'this')}))/3)",
542
+ exp.LastDay: no_last_day_sql,
311
543
  }
312
544
 
313
545
  def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
@@ -324,10 +556,11 @@ class Exasol(Dialect):
324
556
  false = self.sql(expression, "false")
325
557
  return f"IF {this} THEN {true} ELSE {false} ENDIF"
326
558
 
327
- def dateadd_sql(self, expression: exp.DateAdd) -> str:
328
- unit = expression.text("unit").upper() or "DAY"
329
- if unit not in DATE_UNITS:
330
- self.unsupported(f"'{unit}' is not supported in Exasol.")
331
- return self.function_fallback_sql(expression)
559
+ def collate_sql(self, expression: exp.Collate) -> str:
560
+ return self.sql(expression.this)
332
561
 
333
- return self.func(f"ADD_{unit}S", expression.this, expression.expression)
562
+ # https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/rank.htm
563
+ def rank_sql(self, expression: exp.Rank) -> str:
564
+ if expression.args.get("expressions"):
565
+ self.unsupported("Exasol does not support arguments in RANK")
566
+ return self.func("RANK")
sqlglot/dialects/hive.py CHANGED
@@ -46,6 +46,7 @@ from sqlglot.helper import seq_get
46
46
  from sqlglot.tokens import TokenType
47
47
  from sqlglot.generator import unsupported_args
48
48
  from sqlglot.optimizer.annotate_types import TypeAnnotator
49
+ from sqlglot.typing.hive import EXPRESSION_METADATA
49
50
 
50
51
  # (FuncType, Multiplier)
51
52
  DATE_DELTA_INTERVAL = {
@@ -216,13 +217,11 @@ class Hive(Dialect):
216
217
  # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
217
218
  NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
218
219
 
219
- ANNOTATORS = {
220
- **Dialect.ANNOTATORS,
221
- exp.If: lambda self, e: self._annotate_by_args(e, "true", "false", promote=True),
222
- exp.Coalesce: lambda self, e: self._annotate_by_args(
223
- e, "this", "expressions", promote=True
224
- ),
225
- }
220
+ EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
221
+
222
+ # https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362046#LanguageManualUDF-StringFunctions
223
+ # https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java#L266-L269
224
+ INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\u000b\u001c\u001d\u001e\u001f"
226
225
 
227
226
  # Support only the non-ANSI mode (default for Hive, Spark2, Spark)
228
227
  COERCES_TO = defaultdict(set, deepcopy(TypeAnnotator.COERCES_TO))
@@ -576,6 +575,7 @@ class Hive(Dialect):
576
575
  exp.ApproxDistinct: approx_count_distinct_sql,
577
576
  exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
578
577
  exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
578
+ exp.Array: transforms.preprocess([transforms.inherit_struct_field_names]),
579
579
  exp.ArrayConcat: rename_func("CONCAT"),
580
580
  exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
581
581
  exp.ArraySort: _array_sort_sql,
@@ -838,7 +838,7 @@ class Hive(Dialect):
838
838
  return f"SET{serde}{exprs}{location}{file_format}{tags}"
839
839
 
840
840
  def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str:
841
- prefix = "WITH " if expression.args.get("with") else ""
841
+ prefix = "WITH " if expression.args.get("with_") else ""
842
842
  exprs = self.expressions(expression, flat=True)
843
843
 
844
844
  return f"{prefix}SERDEPROPERTIES ({exprs})"
sqlglot/dialects/mysql.py CHANGED
@@ -163,6 +163,7 @@ class MySQL(Dialect):
163
163
  SUPPORTS_USER_DEFINED_TYPES = False
164
164
  SUPPORTS_SEMI_ANTI_JOIN = False
165
165
  SAFE_DIVISION = True
166
+ SAFE_TO_ELIMINATE_DOUBLE_NEGATION = False
166
167
 
167
168
  # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
168
169
  TIME_MAPPING = {
@@ -201,6 +202,8 @@ class MySQL(Dialect):
201
202
  STRING_ESCAPES = ["'", '"', "\\"]
202
203
  BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
203
204
  HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
205
+ # https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
206
+ ESCAPE_FOLLOW_CHARS = ["0", "b", "n", "r", "t", "Z", "%", "_"]
204
207
 
205
208
  NESTED_COMMENTS = False
206
209
 
@@ -325,7 +328,7 @@ class MySQL(Dialect):
325
328
  "BIT_AND": exp.BitwiseAndAgg.from_arg_list,
326
329
  "BIT_OR": exp.BitwiseOrAgg.from_arg_list,
327
330
  "BIT_XOR": exp.BitwiseXorAgg.from_arg_list,
328
- "BIT_COUNT": exp.BitwiseCountAgg.from_arg_list,
331
+ "BIT_COUNT": exp.BitwiseCount.from_arg_list,
329
332
  "CONVERT_TZ": lambda args: exp.ConvertTimezone(
330
333
  source_tz=seq_get(args, 1), target_tz=seq_get(args, 2), timestamp=seq_get(args, 0)
331
334
  ),
@@ -463,6 +466,7 @@ class MySQL(Dialect):
463
466
  "INDEX": lambda self: self._parse_index_constraint(),
464
467
  "KEY": lambda self: self._parse_index_constraint(),
465
468
  "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
469
+ "ZEROFILL": lambda self: self.expression(exp.ZeroFillColumnConstraint),
466
470
  }
467
471
 
468
472
  ALTER_PARSERS = {
@@ -670,7 +674,7 @@ class MySQL(Dialect):
670
674
  for_role=for_role,
671
675
  into_outfile=into_outfile,
672
676
  json=json,
673
- **{"global": global_}, # type: ignore
677
+ global_=global_,
674
678
  )
675
679
 
676
680
  def _parse_oldstyle_limit(
@@ -755,7 +759,7 @@ class MySQL(Dialect):
755
759
  exp.BitwiseAndAgg: rename_func("BIT_AND"),
756
760
  exp.BitwiseOrAgg: rename_func("BIT_OR"),
757
761
  exp.BitwiseXorAgg: rename_func("BIT_XOR"),
758
- exp.BitwiseCountAgg: rename_func("BIT_COUNT"),
762
+ exp.BitwiseCount: rename_func("BIT_COUNT"),
759
763
  exp.CurrentDate: no_paren_current_date_sql,
760
764
  exp.DateDiff: _remove_ts_or_ds_to_date(
761
765
  lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
@@ -797,6 +801,7 @@ class MySQL(Dialect):
797
801
  exp.StrToDate: _str_to_date_sql,
798
802
  exp.StrToTime: _str_to_date_sql,
799
803
  exp.Stuff: rename_func("INSERT"),
804
+ exp.SessionUser: lambda *_: "SESSION_USER()",
800
805
  exp.TableSample: no_tablesample_sql,
801
806
  exp.TimeFromParts: rename_func("MAKETIME"),
802
807
  exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
@@ -1228,7 +1233,7 @@ class MySQL(Dialect):
1228
1233
  def show_sql(self, expression: exp.Show) -> str:
1229
1234
  this = f" {expression.name}"
1230
1235
  full = " FULL" if expression.args.get("full") else ""
1231
- global_ = " GLOBAL" if expression.args.get("global") else ""
1236
+ global_ = " GLOBAL" if expression.args.get("global_") else ""
1232
1237
 
1233
1238
  target = self.sql(expression, "target")
1234
1239
  target = f" {target}" if target else ""
@@ -1329,6 +1334,11 @@ class MySQL(Dialect):
1329
1334
  def isascii_sql(self, expression: exp.IsAscii) -> str:
1330
1335
  return f"REGEXP_LIKE({self.sql(expression.this)}, '^[[:ascii:]]*$')"
1331
1336
 
1337
+ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
1338
+ # https://dev.mysql.com/doc/refman/8.4/en/window-function-descriptions.html
1339
+ self.unsupported("MySQL does not support IGNORE NULLS.")
1340
+ return self.sql(expression.this)
1341
+
1332
1342
  @unsupported_args("this")
1333
1343
  def currentschema_sql(self, expression: exp.CurrentSchema) -> str:
1334
1344
  return self.func("SCHEMA")
@@ -45,6 +45,7 @@ class Oracle(Dialect):
45
45
  NULL_ORDERING = "nulls_are_large"
46
46
  ON_CONDITION_EMPTY_BEFORE_ERROR = False
47
47
  ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
48
+ DISABLES_ALIAS_REF_EXPANSION = True
48
49
 
49
50
  # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
50
51
  NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -72,6 +73,15 @@ class Oracle(Dialect):
72
73
  "FF6": "%f", # only 6 digits are supported in python formats
73
74
  }
74
75
 
76
+ PSEUDOCOLUMNS = {"ROWNUM", "ROWID", "OBJECT_ID", "OBJECT_VALUE", "LEVEL"}
77
+
78
+ def can_quote(self, identifier: exp.Identifier, identify: str | bool = "safe") -> bool:
79
+ # Disable quoting for pseudocolumns as it may break queries e.g
80
+ # `WHERE "ROWNUM" = ...` does not work but `WHERE ROWNUM = ...` does
81
+ return (
82
+ identifier.quoted or not isinstance(identifier.parent, exp.Pseudocolumn)
83
+ ) and super().can_quote(identifier, identify=identify)
84
+
75
85
  class Tokenizer(tokens.Tokenizer):
76
86
  VAR_SINGLE_TOKENS = {"@", "$", "#"}
77
87
 
@@ -119,6 +129,7 @@ class Oracle(Dialect):
119
129
  unabbreviate=False,
120
130
  ),
121
131
  }
132
+ FUNCTIONS.pop("TO_BOOLEAN")
122
133
 
123
134
  NO_PAREN_FUNCTION_PARSERS = {
124
135
  **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
@@ -264,6 +275,24 @@ class Oracle(Dialect):
264
275
  def _parse_connect_with_prior(self):
265
276
  return self._parse_assignment()
266
277
 
278
+ def _parse_insert_table(self) -> t.Optional[exp.Expression]:
279
+ # Oracle does not use AS for INSERT INTO alias
280
+ # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/INSERT.html
281
+ # Parse table parts without schema to avoid parsing the alias with its columns
282
+ this = self._parse_table_parts(schema=True)
283
+
284
+ if isinstance(this, exp.Table):
285
+ alias_name = self._parse_id_var(any_token=False)
286
+ if alias_name:
287
+ this.set("alias", exp.TableAlias(this=alias_name))
288
+
289
+ this.set("partition", self._parse_partition())
290
+
291
+ # Now parse the schema (column list) if present
292
+ return self._parse_schema(this=this)
293
+
294
+ return this
295
+
267
296
  class Generator(generator.Generator):
268
297
  LOCKING_READS_SUPPORTED = True
269
298
  JOIN_HINTS = False
@@ -36,7 +36,8 @@ from sqlglot.dialects.dialect import (
36
36
  strposition_sql,
37
37
  count_if_to_sum,
38
38
  groupconcat_sql,
39
- Version,
39
+ regexp_replace_global_modifier,
40
+ sha2_digest_sql,
40
41
  )
41
42
  from sqlglot.generator import unsupported_args
42
43
  from sqlglot.helper import is_int, seq_get
@@ -203,6 +204,7 @@ def _build_regexp_replace(args: t.List, dialect: DialectType = None) -> exp.Rege
203
204
  # Any one of `start`, `N` and `flags` can be column references, meaning that
204
205
  # unless we can statically see that the last argument is a non-integer string
205
206
  # (eg. not '0'), then it's not possible to construct the correct AST
207
+ regexp_replace = None
206
208
  if len(args) > 3:
207
209
  last = args[-1]
208
210
  if not is_int(last.name):
@@ -214,9 +216,10 @@ def _build_regexp_replace(args: t.List, dialect: DialectType = None) -> exp.Rege
214
216
  if last.is_type(*exp.DataType.TEXT_TYPES):
215
217
  regexp_replace = exp.RegexpReplace.from_arg_list(args[:-1])
216
218
  regexp_replace.set("modifiers", last)
217
- return regexp_replace
218
219
 
219
- return exp.RegexpReplace.from_arg_list(args)
220
+ regexp_replace = regexp_replace or exp.RegexpReplace.from_arg_list(args)
221
+ regexp_replace.set("single_replace", True)
222
+ return regexp_replace
220
223
 
221
224
 
222
225
  def _unix_to_time_sql(self: Postgres.Generator, expression: exp.UnixToTime) -> str:
@@ -259,12 +262,35 @@ def _levenshtein_sql(self: Postgres.Generator, expression: exp.Levenshtein) -> s
259
262
  def _versioned_anyvalue_sql(self: Postgres.Generator, expression: exp.AnyValue) -> str:
260
263
  # https://www.postgresql.org/docs/16/functions-aggregate.html
261
264
  # https://www.postgresql.org/about/featurematrix/
262
- if self.dialect.version < Version("16.0"):
265
+ if self.dialect.version < (16,):
263
266
  return any_value_to_max_sql(self, expression)
264
267
 
265
268
  return rename_func("ANY_VALUE")(self, expression)
266
269
 
267
270
 
271
+ def _round_sql(self: Postgres.Generator, expression: exp.Round) -> str:
272
+ this = self.sql(expression, "this")
273
+ decimals = self.sql(expression, "decimals")
274
+
275
+ if not decimals:
276
+ return self.func("ROUND", this)
277
+
278
+ if not expression.type:
279
+ from sqlglot.optimizer.annotate_types import annotate_types
280
+
281
+ expression = annotate_types(expression, dialect=self.dialect)
282
+
283
+ # ROUND(double precision, integer) is not permitted in Postgres
284
+ # so it's necessary to cast to decimal before rounding.
285
+ if expression.this.is_type(exp.DataType.Type.DOUBLE):
286
+ decimal_type = exp.DataType.build(
287
+ exp.DataType.Type.DECIMAL, expressions=expression.expressions
288
+ )
289
+ this = self.sql(exp.Cast(this=this, to=decimal_type))
290
+
291
+ return self.func("ROUND", this, decimals)
292
+
293
+
268
294
  class Postgres(Dialect):
269
295
  INDEX_OFFSET = 1
270
296
  TYPED_DIVISION = True
@@ -272,6 +298,11 @@ class Postgres(Dialect):
272
298
  NULL_ORDERING = "nulls_are_large"
273
299
  TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
274
300
  TABLESAMPLE_SIZE_IS_PERCENT = True
301
+ TABLES_REFERENCEABLE_AS_COLUMNS = True
302
+
303
+ DEFAULT_FUNCTIONS_COLUMN_NAMES = {
304
+ exp.ExplodingGenerateSeries: "generate_series",
305
+ }
275
306
 
276
307
  TIME_MAPPING = {
277
308
  "d": "%u", # 1-based day of week
@@ -327,6 +358,8 @@ class Postgres(Dialect):
327
358
  "<@": TokenType.LT_AT,
328
359
  "?&": TokenType.QMARK_AMP,
329
360
  "?|": TokenType.QMARK_PIPE,
361
+ "&<": TokenType.AMP_LT,
362
+ "&>": TokenType.AMP_GT,
330
363
  "#-": TokenType.HASH_DASH,
331
364
  "|/": TokenType.PIPE_SLASH,
332
365
  "||/": TokenType.DPIPE_SLASH,
@@ -446,6 +479,8 @@ class Postgres(Dialect):
446
479
  RANGE_PARSERS = {
447
480
  **parser.Parser.RANGE_PARSERS,
448
481
  TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
482
+ TokenType.AMP_LT: binary_range_parser(exp.ExtendsLeft),
483
+ TokenType.AMP_GT: binary_range_parser(exp.ExtendsRight),
449
484
  TokenType.DAT: lambda self, this: self.expression(
450
485
  exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
451
486
  ),
@@ -651,6 +686,16 @@ class Postgres(Dialect):
651
686
  exp.Rand: rename_func("RANDOM"),
652
687
  exp.RegexpLike: lambda self, e: self.binary(e, "~"),
653
688
  exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
689
+ exp.RegexpReplace: lambda self, e: self.func(
690
+ "REGEXP_REPLACE",
691
+ e.this,
692
+ e.expression,
693
+ e.args.get("replacement"),
694
+ e.args.get("position"),
695
+ e.args.get("occurrence"),
696
+ regexp_replace_global_modifier(e),
697
+ ),
698
+ exp.Round: _round_sql,
654
699
  exp.Select: transforms.preprocess(
655
700
  [
656
701
  transforms.eliminate_semi_and_anti_joins,
@@ -658,6 +703,7 @@ class Postgres(Dialect):
658
703
  ]
659
704
  ),
660
705
  exp.SHA2: sha256_sql,
706
+ exp.SHA2Digest: sha2_digest_sql,
661
707
  exp.StrPosition: lambda self, e: strposition_sql(self, e, func_name="POSITION"),
662
708
  exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
663
709
  exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
@@ -698,28 +744,6 @@ class Postgres(Dialect):
698
744
  exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
699
745
  }
700
746
 
701
- def round_sql(self, expression: exp.Round) -> str:
702
- this = self.sql(expression, "this")
703
- decimals = self.sql(expression, "decimals")
704
-
705
- if not decimals:
706
- return self.func("ROUND", this)
707
-
708
- if not expression.type:
709
- from sqlglot.optimizer.annotate_types import annotate_types
710
-
711
- expression = annotate_types(expression, dialect=self.dialect)
712
-
713
- # ROUND(double precision, integer) is not permitted in Postgres
714
- # so it's necessary to cast to decimal before rounding.
715
- if expression.this.is_type(exp.DataType.Type.DOUBLE):
716
- decimal_type = exp.DataType.build(
717
- exp.DataType.Type.DECIMAL, expressions=expression.expressions
718
- )
719
- this = self.sql(exp.Cast(this=this, to=decimal_type))
720
-
721
- return self.func("ROUND", this, decimals)
722
-
723
747
  def schemacommentproperty_sql(self, expression: exp.SchemaCommentProperty) -> str:
724
748
  self.unsupported("Table comments are not supported in the CREATE statement")
725
749
  return ""
@@ -824,6 +848,16 @@ class Postgres(Dialect):
824
848
  def isascii_sql(self, expression: exp.IsAscii) -> str:
825
849
  return f"({self.sql(expression.this)} ~ '^[[:ascii:]]*$')"
826
850
 
851
+ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
852
+ # https://www.postgresql.org/docs/current/functions-window.html
853
+ self.unsupported("PostgreSQL does not support IGNORE NULLS.")
854
+ return self.sql(expression.this)
855
+
856
+ def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
857
+ # https://www.postgresql.org/docs/current/functions-window.html
858
+ self.unsupported("PostgreSQL does not support RESPECT NULLS.")
859
+ return self.sql(expression.this)
860
+
827
861
  @unsupported_args("this")
828
862
  def currentschema_sql(self, expression: exp.CurrentSchema) -> str:
829
863
  return "CURRENT_SCHEMA"