sqlglot 28.4.1__py3-none-any.whl → 28.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sqlglot/_version.py +2 -2
  2. sqlglot/dialects/bigquery.py +20 -23
  3. sqlglot/dialects/clickhouse.py +2 -0
  4. sqlglot/dialects/dialect.py +355 -18
  5. sqlglot/dialects/doris.py +38 -90
  6. sqlglot/dialects/druid.py +1 -0
  7. sqlglot/dialects/duckdb.py +1739 -163
  8. sqlglot/dialects/exasol.py +17 -1
  9. sqlglot/dialects/hive.py +27 -2
  10. sqlglot/dialects/mysql.py +103 -11
  11. sqlglot/dialects/oracle.py +38 -1
  12. sqlglot/dialects/postgres.py +142 -33
  13. sqlglot/dialects/presto.py +6 -2
  14. sqlglot/dialects/redshift.py +7 -1
  15. sqlglot/dialects/singlestore.py +13 -3
  16. sqlglot/dialects/snowflake.py +271 -21
  17. sqlglot/dialects/spark.py +25 -0
  18. sqlglot/dialects/spark2.py +4 -3
  19. sqlglot/dialects/starrocks.py +152 -17
  20. sqlglot/dialects/trino.py +1 -0
  21. sqlglot/dialects/tsql.py +5 -0
  22. sqlglot/diff.py +1 -1
  23. sqlglot/expressions.py +239 -47
  24. sqlglot/generator.py +173 -44
  25. sqlglot/optimizer/annotate_types.py +129 -60
  26. sqlglot/optimizer/merge_subqueries.py +13 -2
  27. sqlglot/optimizer/qualify_columns.py +7 -0
  28. sqlglot/optimizer/resolver.py +19 -0
  29. sqlglot/optimizer/scope.py +12 -0
  30. sqlglot/optimizer/unnest_subqueries.py +7 -0
  31. sqlglot/parser.py +251 -58
  32. sqlglot/schema.py +186 -14
  33. sqlglot/tokens.py +36 -6
  34. sqlglot/transforms.py +6 -5
  35. sqlglot/typing/__init__.py +29 -10
  36. sqlglot/typing/bigquery.py +5 -10
  37. sqlglot/typing/duckdb.py +39 -0
  38. sqlglot/typing/hive.py +50 -1
  39. sqlglot/typing/mysql.py +32 -0
  40. sqlglot/typing/presto.py +0 -1
  41. sqlglot/typing/snowflake.py +80 -17
  42. sqlglot/typing/spark.py +29 -0
  43. sqlglot/typing/spark2.py +9 -1
  44. sqlglot/typing/tsql.py +21 -0
  45. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/METADATA +47 -2
  46. sqlglot-28.8.0.dist-info/RECORD +95 -0
  47. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/WHEEL +1 -1
  48. sqlglot-28.4.1.dist-info/RECORD +0 -92
  49. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/licenses/LICENSE +0 -0
  50. {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,15 @@ if t.TYPE_CHECKING:
30
30
 
31
31
  logger = logging.getLogger("sqlglot")
32
32
 
33
+ # EXTRACT/DATE_PART specifiers that return BIGINT instead of INT
34
+ BIGINT_EXTRACT_DATE_PARTS = {
35
+ "EPOCH_SECOND",
36
+ "EPOCH_MILLISECOND",
37
+ "EPOCH_MICROSECOND",
38
+ "EPOCH_NANOSECOND",
39
+ "NANOSECOND",
40
+ }
41
+
33
42
 
34
43
  def annotate_types(
35
44
  expression: E,
@@ -213,10 +222,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
213
222
  # When set to False, this enables partial annotation by skipping already-annotated nodes
214
223
  self._overwrite_types = overwrite_types
215
224
 
225
+ # Maps Scope to its corresponding selected sources
226
+ self._scope_selects: t.Dict[Scope, t.Dict[str, t.Dict[str, t.Any]]] = {}
227
+
216
228
  def clear(self) -> None:
217
229
  self._visited.clear()
218
230
  self._null_expressions.clear()
219
231
  self._setop_column_types.clear()
232
+ self._scope_selects.clear()
220
233
 
221
234
  def _set_type(
222
235
  self, expression: E, target_type: t.Optional[exp.DataType | exp.DataType.Type]
@@ -268,53 +281,58 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
268
281
 
269
282
  return expression
270
283
 
271
- def annotate_scope(self, scope: Scope) -> None:
272
- selects = {}
273
-
274
- for name, source in scope.sources.items():
275
- if not isinstance(source, Scope):
276
- continue
277
-
278
- expression = source.expression
279
- if isinstance(expression, exp.UDTF):
280
- values = []
284
+ def _get_scope_selects(self, scope: Scope) -> t.Dict[str, t.Dict[str, t.Any]]:
285
+ if scope not in self._scope_selects:
286
+ selects = {}
287
+ for name, source in scope.sources.items():
288
+ if not isinstance(source, Scope):
289
+ continue
281
290
 
282
- if isinstance(expression, exp.Lateral):
283
- if isinstance(expression.this, exp.Explode):
284
- values = [expression.this.this]
285
- elif isinstance(expression, exp.Unnest):
286
- values = [expression]
287
- elif not isinstance(expression, exp.TableFromRows):
288
- values = expression.expressions[0].expressions
291
+ expression = source.expression
292
+ if isinstance(expression, exp.UDTF):
293
+ values = []
289
294
 
290
- if not values:
291
- continue
295
+ if isinstance(expression, exp.Lateral):
296
+ if isinstance(expression.this, exp.Explode):
297
+ values = [expression.this.this]
298
+ elif isinstance(expression, exp.Unnest):
299
+ values = [expression]
300
+ elif not isinstance(expression, exp.TableFromRows):
301
+ values = expression.expressions[0].expressions
292
302
 
293
- alias_column_names = expression.alias_column_names
303
+ if not values:
304
+ continue
294
305
 
295
- if (
296
- isinstance(expression, exp.Unnest)
297
- and not alias_column_names
298
- and expression.type
299
- and expression.type.is_type(exp.DataType.Type.STRUCT)
300
- ):
301
- selects[name] = {
302
- col_def.name: t.cast(t.Union[exp.DataType, exp.DataType.Type], col_def.kind)
303
- for col_def in expression.type.expressions
304
- if isinstance(col_def, exp.ColumnDef) and col_def.kind
305
- }
306
+ alias_column_names = expression.alias_column_names
307
+
308
+ if (
309
+ isinstance(expression, exp.Unnest)
310
+ and expression.type
311
+ and expression.type.is_type(exp.DataType.Type.STRUCT)
312
+ ):
313
+ selects[name] = {
314
+ col_def.name: t.cast(
315
+ t.Union[exp.DataType, exp.DataType.Type], col_def.kind
316
+ )
317
+ for col_def in expression.type.expressions
318
+ if isinstance(col_def, exp.ColumnDef) and col_def.kind
319
+ }
320
+ else:
321
+ selects[name] = {
322
+ alias: column.type for alias, column in zip(alias_column_names, values)
323
+ }
324
+ elif isinstance(expression, exp.SetOperation) and len(
325
+ expression.left.selects
326
+ ) == len(expression.right.selects):
327
+ selects[name] = self._get_setop_column_types(expression)
306
328
  else:
307
- selects[name] = {
308
- alias: column.type for alias, column in zip(alias_column_names, values)
309
- }
310
- elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
311
- expression.right.selects
312
- ):
313
- selects[name] = self._get_setop_column_types(expression)
329
+ selects[name] = {s.alias_or_name: s.type for s in expression.selects}
314
330
 
315
- else:
316
- selects[name] = {s.alias_or_name: s.type for s in expression.selects}
331
+ self._scope_selects[scope] = selects
332
+
333
+ return self._scope_selects[scope]
317
334
 
335
+ def annotate_scope(self, scope: Scope) -> None:
318
336
  if isinstance(self.schema, MappingSchema):
319
337
  for table_column in scope.table_columns:
320
338
  source = scope.sources.get(table_column.name)
@@ -345,7 +363,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
345
363
  self._set_type(table_column, source.expression.meta["query_type"])
346
364
 
347
365
  # Iterate through all the expressions of the current scope in post-order, and annotate
348
- self._annotate_expression(scope.expression, scope, selects)
366
+ self._annotate_expression(scope.expression, scope)
349
367
 
350
368
  if self.dialect.QUERY_RESULTS_ARE_STRUCTS and isinstance(scope.expression, exp.Query):
351
369
  struct_type = exp.DataType(
@@ -374,10 +392,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
374
392
  self,
375
393
  expression: exp.Expression,
376
394
  scope: t.Optional[Scope] = None,
377
- selects: t.Optional[t.Dict[str, t.Dict[str, t.Any]]] = None,
378
395
  ) -> None:
379
396
  stack = [(expression, False)]
380
- selects = selects or {}
381
397
 
382
398
  while stack:
383
399
  expr, children_annotated = stack.pop()
@@ -396,12 +412,21 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
396
412
  continue
397
413
 
398
414
  if scope and isinstance(expr, exp.Column) and expr.table:
399
- source = scope.sources.get(expr.table)
415
+ source = None
416
+ source_scope = scope
417
+ while source_scope and not source:
418
+ source = source_scope.sources.get(expr.table)
419
+ if not source:
420
+ source_scope = source_scope.parent
421
+
400
422
  if isinstance(source, exp.Table):
401
423
  self._set_type(expr, self.schema.get_column_type(source, expr))
402
424
  elif source:
403
- if expr.table in selects and expr.name in selects[expr.table]:
404
- self._set_type(expr, selects[expr.table][expr.name])
425
+ col_type = (
426
+ self._get_scope_selects(source_scope).get(expr.table, {}).get(expr.name)
427
+ )
428
+ if col_type:
429
+ self._set_type(expr, col_type)
405
430
  elif isinstance(source.expression, exp.Unnest):
406
431
  self._set_type(expr, source.expression.type)
407
432
  else:
@@ -536,7 +561,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
536
561
  elif (left_type, right_type) in self.binary_coercions:
537
562
  self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
538
563
  else:
539
- self._set_type(expression, self._maybe_coerce(left_type, right_type))
564
+ self._annotate_by_args(expression, left, right)
540
565
 
541
566
  if isinstance(expression, exp.Is) or (
542
567
  left.meta.get("nonnull") is True and right.meta.get("nonnull") is True
@@ -572,28 +597,64 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
572
597
  def _annotate_by_args(
573
598
  self,
574
599
  expression: E,
575
- *args: str,
600
+ *args: str | exp.Expression,
576
601
  promote: bool = False,
577
602
  array: bool = False,
578
603
  ) -> E:
579
- expressions: t.List[exp.Expression] = []
604
+ literal_type = None
605
+ non_literal_type = None
606
+ nested_type = None
607
+
580
608
  for arg in args:
581
- arg_expr = expression.args.get(arg)
582
- expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
609
+ if isinstance(arg, str):
610
+ expressions = expression.args.get(arg)
611
+ else:
612
+ expressions = arg
583
613
 
584
- last_datatype = None
585
- for expr in expressions:
586
- expr_type = expr.type
614
+ for expr in ensure_list(expressions):
615
+ expr_type = expr.type
587
616
 
588
- # Stop at the first nested data type found - we don't want to _maybe_coerce nested types
589
- if expr_type.args.get("nested"):
590
- last_datatype = expr_type
617
+ # Stop at the first nested data type found - we don't want to _maybe_coerce nested types
618
+ if expr_type.args.get("nested"):
619
+ nested_type = expr_type
620
+ break
621
+
622
+ if isinstance(expr, exp.Literal):
623
+ literal_type = self._maybe_coerce(literal_type or expr_type, expr_type)
624
+ else:
625
+ non_literal_type = self._maybe_coerce(non_literal_type or expr_type, expr_type)
626
+
627
+ if nested_type:
591
628
  break
592
629
 
593
- if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
594
- last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
630
+ result_type = None
631
+
632
+ if nested_type:
633
+ result_type = nested_type
634
+ elif literal_type and non_literal_type:
635
+ if self.dialect.PRIORITIZE_NON_LITERAL_TYPES:
636
+ literal_this_type = (
637
+ literal_type.this if isinstance(literal_type, exp.DataType) else literal_type
638
+ )
639
+ non_literal_this_type = (
640
+ non_literal_type.this
641
+ if isinstance(non_literal_type, exp.DataType)
642
+ else non_literal_type
643
+ )
644
+ if (
645
+ literal_this_type in exp.DataType.INTEGER_TYPES
646
+ and non_literal_this_type in exp.DataType.INTEGER_TYPES
647
+ ) or (
648
+ literal_this_type in exp.DataType.REAL_TYPES
649
+ and non_literal_this_type in exp.DataType.REAL_TYPES
650
+ ):
651
+ result_type = non_literal_type
652
+ else:
653
+ result_type = literal_type or non_literal_type or exp.DataType.Type.UNKNOWN
595
654
 
596
- self._set_type(expression, last_datatype)
655
+ self._set_type(
656
+ expression, result_type or self._maybe_coerce(non_literal_type, literal_type)
657
+ )
597
658
 
598
659
  if promote:
599
660
  if expression.type.this in exp.DataType.INTEGER_TYPES:
@@ -661,6 +722,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
661
722
 
662
723
  def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
663
724
  self._set_type(expression, None)
725
+
726
+ # Propagate type from qualified UDF calls (e.g., db.my_udf(...))
727
+ if isinstance(expression.expression, exp.Anonymous):
728
+ self._set_type(expression, expression.expression.type)
729
+ return expression
730
+
664
731
  this_type = expression.this.type
665
732
 
666
733
  if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
@@ -784,6 +851,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
784
851
  self._set_type(expression, exp.DataType.Type.TIME)
785
852
  elif part == "DATE":
786
853
  self._set_type(expression, exp.DataType.Type.DATE)
854
+ elif part in BIGINT_EXTRACT_DATE_PARTS:
855
+ self._set_type(expression, exp.DataType.Type.BIGINT)
787
856
  else:
788
857
  self._set_type(expression, exp.DataType.Type.INT)
789
858
  return expression
@@ -326,14 +326,25 @@ def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> No
326
326
  expression = expression.unalias()
327
327
  must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)
328
328
 
329
+ is_number = expression.is_number
330
+
329
331
  for column in columns_to_replace:
332
+ parent = column.parent
333
+
334
+ # Ensures that we don't merge literal numbers in GROUP BY as they have positional context
335
+ # e.g. don't transform `SELECT a FROM (SELECT 6 AS a) GROUP BY a` to `SELECT 6 AS a GROUP BY 6`,
336
+ # as this would attempt to GROUP BY the 6th projection instead of the column `a`
337
+ if is_number and isinstance(parent, exp.Group):
338
+ column.replace(exp.to_identifier(column.name))
339
+ continue
340
+
330
341
  # Ensures we don't alter the intended operator precedence if there's additional
331
342
  # context surrounding the outer expression (i.e. it's not a simple projection).
332
- if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
343
+ if isinstance(parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
333
344
  expression = exp.paren(expression, copy=False)
334
345
 
335
346
  # make sure we do not accidentally change the name of the column
336
- if isinstance(column.parent, exp.Select) and column.name != expression.name:
347
+ if isinstance(parent, exp.Select) and column.name != expression.name:
337
348
  expression = exp.alias_(expression, column.name)
338
349
 
339
350
  column.replace(expression.copy())
@@ -610,6 +610,13 @@ def _qualify_columns(
610
610
  # column_table can be a '' because bigquery unnest has no table alias
611
611
  column_table = resolver.get_table(column)
612
612
 
613
+ if (
614
+ column_table
615
+ and isinstance(source := scope.sources.get(column_table.name), Scope)
616
+ and id(column) in source.column_index
617
+ ):
618
+ continue
619
+
613
620
  if column_table:
614
621
  column.set("table", column_table)
615
622
  elif (
@@ -305,6 +305,21 @@ class Resolver:
305
305
  # Performance optimization - avoid copying first_columns if there is only one table.
306
306
  return SingleValuedMapping(first_columns, first_table)
307
307
 
308
+ # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
309
+ # from alias.columns[0] to their source names. This is used to resolve shadowing
310
+ # where an UNNEST alias shadows a column name from another table.
311
+ unnest_original_aliases: t.Dict[str, str] = {}
312
+ if self.dialect.UNNEST_COLUMN_ONLY:
313
+ unnest_original_aliases = {
314
+ alias_arg.columns[0].name: source_name
315
+ for source_name, source in self.scope.sources.items()
316
+ if (
317
+ isinstance(source.expression, exp.Unnest)
318
+ and (alias_arg := source.expression.args.get("alias"))
319
+ and alias_arg.columns
320
+ )
321
+ }
322
+
308
323
  unambiguous_columns = {col: first_table for col in first_columns}
309
324
  all_columns = set(unambiguous_columns)
310
325
 
@@ -314,6 +329,10 @@ class Resolver:
314
329
  all_columns.update(columns)
315
330
 
316
331
  for column in ambiguous:
332
+ if column in unnest_original_aliases:
333
+ unambiguous_columns[column] = unnest_original_aliases[column]
334
+ continue
335
+
317
336
  unambiguous_columns.pop(column, None)
318
337
  for column in unique.difference(ambiguous):
319
338
  unambiguous_columns[column] = table
@@ -103,6 +103,7 @@ class Scope:
103
103
  self._pivots = None
104
104
  self._references = None
105
105
  self._semi_anti_join_tables = None
106
+ self._column_index = None
106
107
 
107
108
  def branch(
108
109
  self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
@@ -131,6 +132,7 @@ class Scope:
131
132
  self._stars = []
132
133
  self._join_hints = []
133
134
  self._semi_anti_join_tables = set()
135
+ self._column_index = set()
134
136
 
135
137
  for node in self.walk(bfs=False):
136
138
  if node is self.expression:
@@ -139,6 +141,8 @@ class Scope:
139
141
  if isinstance(node, exp.Dot) and node.is_star:
140
142
  self._stars.append(node)
141
143
  elif isinstance(node, exp.Column) and not isinstance(node, exp.Pseudocolumn):
144
+ self._column_index.add(id(node))
145
+
142
146
  if isinstance(node.this, exp.Star):
143
147
  self._stars.append(node)
144
148
  else:
@@ -259,6 +263,14 @@ class Scope:
259
263
  self._ensure_collected()
260
264
  return self._stars
261
265
 
266
+ @property
267
+ def column_index(self) -> t.Set[int]:
268
+ """
269
+ Set of column object IDs that belong to this scope's expression.
270
+ """
271
+ self._ensure_collected()
272
+ return self._column_index
273
+
262
274
  @property
263
275
  def columns(self):
264
276
  """
@@ -43,6 +43,12 @@ def unnest(select, parent_select, next_alias_name):
43
43
  predicate = select.find_ancestor(exp.Condition)
44
44
  if (
45
45
  not predicate
46
+ # Do not unnest subqueries inside table-valued functions such as
47
+ # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
48
+ or (
49
+ isinstance(predicate, exp.Func)
50
+ and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
51
+ )
46
52
  or parent_select is not predicate.parent_select
47
53
  or not parent_select.args.get("from_")
48
54
  ):
@@ -83,6 +89,7 @@ def unnest(select, parent_select, next_alias_name):
83
89
 
84
90
  _replace(select.parent, column)
85
91
  parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)
92
+
86
93
  return
87
94
 
88
95
  if select.find(exp.Limit, exp.Offset):