sqlglot 27.29.0__py3-none-any.whl → 28.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.1.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.1.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,6 @@ from sqlglot.helper import (
     flatten,
     is_int,
     seq_get,
-    subclasses,
    suggest_closest_match_and_fail,
    to_bool,
 )
@@ -26,6 +25,7 @@ from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time, subsecond_precision
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie
+from sqlglot.typing import EXPRESSION_METADATA

 DATE_ADD_OR_DIFF = t.Union[
     exp.DateAdd,
@@ -53,10 +53,6 @@ DATETIME_ADD = (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.Ti
 if t.TYPE_CHECKING:
     from sqlglot._typing import B, E, F

-    from sqlglot.optimizer.annotate_types import TypeAnnotator
-
-    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]
-
 logger = logging.getLogger("sqlglot")

 UNESCAPED_SEQUENCES = {
@@ -71,10 +67,6 @@ UNESCAPED_SEQUENCES = {
 }


-def annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
-    return lambda self, e: self._annotate_with_type(e, data_type)
-
-
 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""

@@ -132,20 +124,6 @@ class NormalizationStrategy(str, AutoName):
     """Always case-insensitive (uppercase), regardless of quotes."""


-class Version(int):
-    def __new__(cls, version_str: t.Optional[str], *args, **kwargs):
-        if version_str:
-            parts = version_str.split(".")
-            parts.extend(["0"] * (3 - len(parts)))
-            v = int("".join([p.zfill(3) for p in parts]))
-        else:
-            # No version defined means we should support the latest engine semantics, so
-            # the comparison to any specific version should yield that latest is greater
-            v = sys.maxsize
-
-        return super(Version, cls).__new__(cls, v)
-
-
 class _Dialect(type):
     _classes: t.Dict[str, t.Type[Dialect]] = {}

@@ -207,7 +185,11 @@ class _Dialect(type):
         klass.FORMAT_TRIE = (
             new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
         )
-        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
+        # Merge class-defined INVERSE_TIME_MAPPING with auto-generated mappings
+        # This allows dialects to define custom inverse mappings for roundtrip correctness
+        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} | (
+            klass.__dict__.get("INVERSE_TIME_MAPPING") or {}
+        )
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
         klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
         klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
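For reference, a minimal Python model of the new merge (mappings hypothetical): dict union keeps the right-hand operand on key collisions, so a dialect's hand-written inverse entries win over the auto-generated ones.

```python
# Hypothetical mappings illustrating the merge order above
TIME_MAPPING = {"%d": "DD", "%j": "DDD"}

auto_inverse = {v: k for k, v in TIME_MAPPING.items()}  # {"DD": "%d", "DDD": "%j"}
hand_written = {"DDD": "%-j"}  # a dialect's custom inverse entry (hypothetical)

# dict union: the right-hand side wins, so the override takes precedence
assert auto_inverse | hand_written == {"DD": "%d", "DDD": "%-j"}
```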
@@ -263,6 +245,9 @@ class _Dialect(type):

         klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

+        if enum not in ("", "bigquery", "snowflake"):
+            klass.INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()

@@ -294,6 +279,54 @@ class _Dialect(type):
             TokenType.SEMI,
         }

+        if enum not in (
+            "",
+            "postgres",
+            "duckdb",
+            "redshift",
+            "snowflake",
+            "presto",
+            "trino",
+            "mysql",
+            "singlestore",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions.pop(TokenType.LOCALTIME, None)
+            if enum != "oracle":
+                no_paren_functions.pop(TokenType.LOCALTIMESTAMP, None)
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+
+        if enum in (
+            "",
+            "postgres",
+            "duckdb",
+            "trino",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions[TokenType.CURRENT_CATALOG] = exp.CurrentCatalog
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+        else:
+            # For dialects that don't support this keyword, treat it as a regular identifier
+            # This fixes the "Unexpected token" error in BQ, Spark, etc.
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.CURRENT_CATALOG,
+            }
+
+        if enum in (
+            "",
+            "duckdb",
+            "spark",
+            "postgres",
+            "tsql",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions[TokenType.SESSION_USER] = exp.SessionUser
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+        else:
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.SESSION_USER,
+            }
+
         klass.VALID_INTERVAL_UNITS = {
             *klass.VALID_INTERVAL_UNITS,
             *klass.DATE_PART_MAPPING.keys(),
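A sketch of the behavior these branches imply (not verified against the released wheel): dialects that keep CURRENT_CATALOG in NO_PAREN_FUNCTIONS parse it as a function expression, while the rest now accept it as an ordinary identifier instead of raising.

```python
import sqlglot

# In Postgres, CURRENT_CATALOG stays a no-parenthesis function...
print(repr(sqlglot.parse_one("SELECT CURRENT_CATALOG", read="postgres")))

# ...while in BigQuery the token is demoted to an identifier, so this no
# longer fails with an "Unexpected token" error
print(repr(sqlglot.parse_one("SELECT current_catalog FROM t", read="bigquery")))
```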
@@ -462,14 +495,139 @@ class Dialect(metaclass=_Dialect):
     to "WHERE id = 1 GROUP BY id HAVING id = 1"
     """

-    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
+    EXPAND_ONLY_GROUP_ALIAS_REF = False
     """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

+    ANNOTATE_ALL_SCOPES = False
+    """Whether to annotate all scopes during optimization. Used by BigQuery for UNNEST support."""
+
+    DISABLES_ALIAS_REF_EXPANSION = False
+    """
+    Whether alias reference expansion is disabled for this dialect.
+
+    Some dialects like Oracle do NOT support referencing aliases in projections or WHERE clauses.
+    The original expression must be repeated instead.
+
+    For example, in Oracle:
+        SELECT y.foo AS bar, bar * 2 AS baz FROM y    -- INVALID
+        SELECT y.foo AS bar, y.foo * 2 AS baz FROM y  -- VALID
+    """
+
+    SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = False
+    """
+    Whether alias references are allowed in JOIN ... ON clauses.
+
+    Most dialects do not support this, but Snowflake allows alias expansion in the JOIN ... ON
+    clause (and almost everywhere else)
+
+    For example, in Snowflake:
+        SELECT a.id AS user_id FROM a JOIN b ON user_id = b.id  -- VALID
+
+    Reference: https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
+    """
+
     SUPPORTS_ORDER_BY_ALL = False
     """
     Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
     """

+    PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = False
+    """
+    Whether projection alias names can shadow table/source names in GROUP BY and HAVING clauses.
+
+    In BigQuery, when a projection alias has the same name as a source table, the alias takes
+    precedence in GROUP BY and HAVING clauses, and the table becomes inaccessible by that name.
+
+    For example, in BigQuery:
+        SELECT id, ARRAY_AGG(col) AS custom_fields
+        FROM custom_fields
+        GROUP BY id
+        HAVING id >= 1
+
+    The "custom_fields" source is shadowed by the projection alias, so we cannot qualify "id"
+    with "custom_fields" in GROUP BY/HAVING.
+    """
+
+    TABLES_REFERENCEABLE_AS_COLUMNS = False
+    """
+    Whether table names can be referenced as columns (treated as structs).
+
+    BigQuery allows tables to be referenced as columns in queries, automatically treating
+    them as struct values containing all the table's columns.
+
+    For example, in BigQuery:
+        SELECT t FROM my_table AS t  -- Returns entire row as a struct
+    """
+
+    SUPPORTS_STRUCT_STAR_EXPANSION = False
+    """
+    Whether the dialect supports expanding struct fields using star notation (e.g., struct_col.*).
+
+    BigQuery allows struct fields to be expanded with the star operator:
+        SELECT t.struct_col.* FROM table t
+    RisingWave also allows struct field expansion with the star operator using parentheses:
+        SELECT (t.struct_col).* FROM table t
+
+    This expands to all fields within the struct.
+    """
+
+    EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = False
+    """
+    Whether pseudocolumns should be excluded from star expansion (SELECT *).
+
+    Pseudocolumns are special dialect-specific columns (e.g., Oracle's ROWNUM, ROWID, LEVEL,
+    or BigQuery's _PARTITIONTIME, _PARTITIONDATE) that are implicitly available but not part
+    of the table schema. When this is True, SELECT * will not include these pseudocolumns;
+    they must be explicitly selected.
+    """
+
+    QUERY_RESULTS_ARE_STRUCTS = False
+    """
+    Whether query results are typed as structs in metadata for type inference.
+
+    In BigQuery, subqueries store their column types as a STRUCT in metadata,
+    enabling special type inference for ARRAY(SELECT ...) expressions:
+        ARRAY(SELECT x, y FROM t) → ARRAY<STRUCT<...>>
+
+    For single column subqueries, BigQuery unwraps the struct:
+        ARRAY(SELECT x FROM t) → ARRAY<type_of_x>
+
+    This is metadata-only for type inference.
+    """
+
+    REQUIRES_PARENTHESIZED_STRUCT_ACCESS = False
+    """
+    Whether struct field access requires parentheses around the expression.
+
+    RisingWave requires parentheses for struct field access in certain contexts:
+        SELECT (col.field).subfield FROM table  -- Parentheses required
+
+    Without parentheses, the parser may not correctly interpret nested struct access.
+
+    Reference: https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
+    """
+
+    SUPPORTS_NULL_TYPE = False
+    """
+    Whether NULL/VOID is supported as a valid data type (not just a value).
+
+    Databricks and Spark v3+ support NULL as an actual type, allowing expressions like:
+        SELECT NULL AS col  -- Has type NULL, not just value NULL
+        CAST(x AS VOID)     -- Valid type cast
+    """
+
+    COALESCE_COMPARISON_NON_STANDARD = False
+    """
+    Whether COALESCE in comparisons has non-standard NULL semantics.
+
+    We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
+    because they are not always equivalent. For example, if `x` is `NULL` and it comes
+    from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`.
+
+    In standard SQL and most dialects, these expressions are equivalent, but Redshift treats
+    table NULLs differently in this context.
+    """
+
     HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
     """
     Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
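These flags are consumed by the qualifier and annotator; a hypothetical dialect would opt in as sketched below (in the wheel, BigQuery, Oracle, RisingWave, etc. set the relevant subset in sqlglot/dialects/*.py).

```python
from sqlglot.dialects.dialect import Dialect

# Hypothetical dialect illustrating how the new flags are declared
class MyDialect(Dialect):
    PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = True  # BigQuery-style alias shadowing
    EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = True        # keep ROWNUM-style pseudocolumns out of SELECT *
    SUPPORTS_NULL_TYPE = True                      # allow CAST(x AS VOID)
```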
@@ -511,6 +669,9 @@ class Dialect(metaclass=_Dialect):
     REGEXP_EXTRACT_DEFAULT_GROUP = 0
     """The default value for the capturing group."""

+    REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL = True
+    """Whether REGEXP_EXTRACT returns NULL when the position arg exceeds the string length."""
+
     SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
         exp.Except: True,
         exp.Intersect: True,
@@ -545,6 +706,41 @@ class Dialect(metaclass=_Dialect):
     # Not safe with MySQL and SQLite due to type coercion (may not return boolean)
     SAFE_TO_ELIMINATE_DOUBLE_NEGATION = True

+    # Whether the INITCAP function supports custom delimiter characters as the second argument
+    # Default delimiter characters for INITCAP function: whitespace and non-alphanumeric characters
+    INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True
+    INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~"
+
+    BYTE_STRING_IS_BYTES_TYPE: bool = False
+    """
+    Whether byte string literals (ex: BigQuery's b'...') are typed as BYTES/BINARY
+    """
+
+    UUID_IS_STRING_TYPE: bool = False
+    """
+    Whether a UUID is considered a string or a UUID type.
+    """
+
+    JSON_EXTRACT_SCALAR_SCALAR_ONLY = False
+    """
+    Whether JSON_EXTRACT_SCALAR returns null if a non-scalar value is selected.
+    """
+
+    DEFAULT_FUNCTIONS_COLUMN_NAMES: t.Dict[t.Type[exp.Func], t.Union[str, t.Tuple[str, ...]]] = {}
+    """
+    Maps function expressions to their default output column name(s).
+
+    For example, in Postgres, generate_series function outputs a column named "generate_series" by default,
+    so we map the ExplodingGenerateSeries expression to "generate_series" string.
+    """
+
+    DEFAULT_NULL_TYPE = exp.DataType.Type.UNKNOWN
+    """
+    The default type of NULL for producing the correct projection type.
+
+    For example, in BigQuery the default type of the NULL value is INT64.
+    """
+
     # --- Autofilled ---

     tokenizer_class = Tokenizer
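The delimiter string is written as a regex character class. A minimal sketch of INITCAP semantics built on it (the helper name and approach are illustrative, not the wheel's implementation):

```python
import re

DELIMS = " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~"

def initcap(text: str, delims: str = DELIMS) -> str:
    # Capitalize the first character and any character that follows a delimiter
    is_delim = re.compile("[" + delims + "]")
    out, capitalize_next = [], True
    for ch in text:
        if is_delim.match(ch):
            out.append(ch)
            capitalize_next = True
        else:
            out.append(ch.upper() if capitalize_next else ch.lower())
            capitalize_next = False
    return "".join(out)

assert initcap("hello-world foo_bar") == "Hello-World Foo_Bar"
```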
@@ -606,6 +802,7 @@ class Dialect(metaclass=_Dialect):
         "WEEKDAY_ISO": "DAYOFWEEKISO",
         "DOW_ISO": "DAYOFWEEKISO",
         "DW_ISO": "DAYOFWEEKISO",
+        "DAYOFWEEK_ISO": "DAYOFWEEKISO",
         "DAY OF YEAR": "DAYOFYEAR",
         "DOY": "DAYOFYEAR",
         "DY": "DAYOFYEAR",
@@ -668,243 +865,21 @@ class Dialect(metaclass=_Dialect):
         "DEC": "DECADE",
         "DECS": "DECADE",
         "DECADES": "DECADE",
-        "MIL": "MILLENIUM",
-        "MILS": "MILLENIUM",
-        "MILLENIA": "MILLENIUM",
+        "MIL": "MILLENNIUM",
+        "MILS": "MILLENNIUM",
+        "MILLENIA": "MILLENNIUM",
         "C": "CENTURY",
         "CENT": "CENTURY",
         "CENTS": "CENTURY",
         "CENTURIES": "CENTURY",
     }

-    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
-        exp.DataType.Type.BIGINT: {
-            exp.ApproxDistinct,
-            exp.ArraySize,
-            exp.CountIf,
-            exp.Int64,
-            exp.Length,
-            exp.UnixDate,
-            exp.UnixSeconds,
-            exp.UnixMicros,
-            exp.UnixMillis,
-        },
-        exp.DataType.Type.BINARY: {
-            exp.FromBase32,
-            exp.FromBase64,
-        },
-        exp.DataType.Type.BOOLEAN: {
-            exp.Between,
-            exp.Boolean,
-            exp.Contains,
-            exp.EndsWith,
-            exp.In,
-            exp.LogicalAnd,
-            exp.LogicalOr,
-            exp.RegexpLike,
-            exp.StartsWith,
-        },
-        exp.DataType.Type.DATE: {
-            exp.CurrentDate,
-            exp.Date,
-            exp.DateFromParts,
-            exp.DateStrToDate,
-            exp.DiToDate,
-            exp.LastDay,
-            exp.StrToDate,
-            exp.TimeStrToDate,
-            exp.TsOrDsToDate,
-        },
-        exp.DataType.Type.DATETIME: {
-            exp.CurrentDatetime,
-            exp.Datetime,
-            exp.DatetimeAdd,
-            exp.DatetimeSub,
-        },
-        exp.DataType.Type.DOUBLE: {
-            exp.ApproxQuantile,
-            exp.Avg,
-            exp.Exp,
-            exp.Ln,
-            exp.Log,
-            exp.Pi,
-            exp.Pow,
-            exp.Quantile,
-            exp.Radians,
-            exp.Round,
-            exp.SafeDivide,
-            exp.Sqrt,
-            exp.Stddev,
-            exp.StddevPop,
-            exp.StddevSamp,
-            exp.ToDouble,
-            exp.Variance,
-            exp.VariancePop,
-        },
-        exp.DataType.Type.INT: {
-            exp.Ascii,
-            exp.Ceil,
-            exp.DatetimeDiff,
-            exp.DateDiff,
-            exp.TimestampDiff,
-            exp.TimeDiff,
-            exp.Unicode,
-            exp.DateToDi,
-            exp.Levenshtein,
-            exp.Sign,
-            exp.StrPosition,
-            exp.TsOrDiToDi,
-        },
-        exp.DataType.Type.INTERVAL: {
-            exp.Interval,
-            exp.JustifyDays,
-            exp.JustifyHours,
-            exp.JustifyInterval,
-            exp.MakeInterval,
-        },
-        exp.DataType.Type.JSON: {
-            exp.ParseJSON,
-        },
-        exp.DataType.Type.TIME: {
-            exp.CurrentTime,
-            exp.Time,
-            exp.TimeAdd,
-            exp.TimeSub,
-        },
-        exp.DataType.Type.TIMESTAMPLTZ: {
-            exp.TimestampLtzFromParts,
-        },
-        exp.DataType.Type.TIMESTAMPTZ: {
-            exp.CurrentTimestampLTZ,
-            exp.TimestampTzFromParts,
-        },
-        exp.DataType.Type.TIMESTAMP: {
-            exp.CurrentTimestamp,
-            exp.StrToTime,
-            exp.TimeStrToTime,
-            exp.TimestampAdd,
-            exp.TimestampSub,
-            exp.UnixToTime,
-        },
-        exp.DataType.Type.TINYINT: {
-            exp.Day,
-            exp.DayOfWeek,
-            exp.DayOfWeekIso,
-            exp.DayOfMonth,
-            exp.DayOfYear,
-            exp.Week,
-            exp.WeekOfYear,
-            exp.Month,
-            exp.Quarter,
-            exp.Year,
-            exp.YearOfWeek,
-            exp.YearOfWeekIso,
-        },
-        exp.DataType.Type.VARCHAR: {
-            exp.ArrayConcat,
-            exp.ArrayToString,
-            exp.Concat,
-            exp.ConcatWs,
-            exp.Chr,
-            exp.DateToDateStr,
-            exp.DPipe,
-            exp.GroupConcat,
-            exp.Initcap,
-            exp.Lower,
-            exp.Substring,
-            exp.String,
-            exp.TimeToStr,
-            exp.TimeToTimeStr,
-            exp.Trim,
-            exp.ToBase32,
-            exp.ToBase64,
-            exp.TsOrDsToDateStr,
-            exp.UnixToStr,
-            exp.UnixToTimeStr,
-            exp.Upper,
-        },
-    }
-
-    ANNOTATORS: AnnotatorsType = {
-        **{
-            expr_type: lambda self, e: self._annotate_unary(e)
-            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
-        },
-        **{
-            expr_type: lambda self, e: self._annotate_binary(e)
-            for expr_type in subclasses(exp.__name__, exp.Binary)
-        },
-        **{
-            expr_type: annotate_with_type_lambda(data_type)
-            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
-            for expr_type in expressions
-        },
-        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
-        exp.AnyValue: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
-        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.ArrayConcatAgg: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArrayFirst: lambda self, e: self._annotate_by_array_element(e),
-        exp.ArrayLast: lambda self, e: self._annotate_by_array_element(e),
-        exp.ArrayReverse: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArraySlice: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Bracket: lambda self, e: self._annotate_bracket(e),
-        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
-        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Count: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
-        ),
-        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
-        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
-        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
-        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
-        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
-        exp.Div: lambda self, e: self._annotate_div(e),
-        exp.Dot: lambda self, e: self._annotate_dot(e),
-        exp.Explode: lambda self, e: self._annotate_explode(e),
-        exp.Extract: lambda self, e: self._annotate_extract(e),
-        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.GenerateSeries: lambda self, e: self._annotate_by_args(
-            e, "start", "end", "step", array=True
-        ),
-        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.build("ARRAY<DATE>")
-        ),
-        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.build("ARRAY<TIMESTAMP>")
-        ),
-        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
-        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Literal: lambda self, e: self._annotate_literal(e),
-        exp.LastValue: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Map: lambda self, e: self._annotate_map(e),
-        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
-        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
-        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
-        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Struct: lambda self, e: self._annotate_struct(e),
-        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
-        exp.SortArray: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Timestamp: lambda self, e: self._annotate_with_type(
-            e,
-            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
-        ),
-        exp.ToMap: lambda self, e: self._annotate_to_map(e),
-        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Unnest: lambda self, e: self._annotate_unnest(e),
-        exp.VarMap: lambda self, e: self._annotate_map(e),
-        exp.Window: lambda self, e: self._annotate_by_args(e, "this"),
-    }
-
     # Specifies what types a given type can be coerced into
     COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

+    # Specifies type inference & validation rules for expressions
+    EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
+
     # Determines the supported Dialect instance settings
     SUPPORTED_SETTINGS = {
         "normalization_strategy",
@@ -984,7 +959,9 @@ class Dialect(metaclass=_Dialect):
         return expression

     def __init__(self, **kwargs) -> None:
-        self.version = Version(kwargs.pop("version", None))
+        parts = str(kwargs.pop("version", sys.maxsize)).split(".")
+        parts.extend(["0"] * (3 - len(parts)))
+        self.version = tuple(int(p) for p in parts[:3])

         normalization_strategy = kwargs.pop("normalization_strategy", None)
         if normalization_strategy is None:
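The old `Version(int)` class (removed above) zero-padded each component to three digits and packed them into one integer, which silently breaks once a component reaches 1000; the replacement compares plain tuples. A minimal model of the new parsing:

```python
import sys

def parse_version(version=None):
    # Mirrors the hunk: a missing version means "latest engine semantics",
    # so it must compare greater than any concrete release
    parts = str(version if version is not None else sys.maxsize).split(".")
    parts.extend(["0"] * (3 - len(parts)))
    return tuple(int(p) for p in parts[:3])

assert parse_version("3.1") == (3, 1, 0)
assert parse_version("3.2") < parse_version("3.10")     # lexicographic tuple compare
assert parse_version() > parse_version("999.999.999")   # unversioned acts as "latest"
```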
@@ -1061,42 +1038,50 @@ class Dialect(metaclass=_Dialect):
         )
         return any(unsafe(char) for char in text)

-    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
-        """Checks if text can be identified given an identify option.
+    def can_quote(self, identifier: exp.Identifier, identify: str | bool = "safe") -> bool:
+        """Checks if an identifier can be quoted

         Args:
-            text: The text to check.
+            identifier: The identifier to check.
             identify:
-                `"always"` or `True`: Always returns `True`.
+                `True`: Always returns `True` except for certain cases.
                 `"safe"`: Only returns `True` if the identifier is case-insensitive.
+                `"unsafe"`: Only returns `True` if the identifier is case-sensitive.

         Returns:
             Whether the given text can be identified.
         """
-        if identify is True or identify == "always":
+        if identifier.quoted:
             return True
+        if not identify:
+            return False
+        if isinstance(identifier.parent, exp.Func):
+            return False
+        if identify is True:
+            return True
+
+        is_safe = not self.case_sensitive(identifier.this) and bool(
+            exp.SAFE_IDENTIFIER_RE.match(identifier.this)
+        )

         if identify == "safe":
-            return not self.case_sensitive(text)
+            return is_safe
+        if identify == "unsafe":
+            return not is_safe

-        return False
+        raise ValueError(f"Unexpected argument for identify: '{identify}'")

     def quote_identifier(self, expression: E, identify: bool = True) -> E:
         """
-        Adds quotes to a given identifier.
+        Adds quotes to a given expression if it is an identifier.

         Args:
             expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
             identify: If set to `False`, the quotes will only be added if the identifier is deemed
                 "unsafe", with respect to its characters and this dialect's normalization strategy.
         """
-        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
-            name = expression.this
-            expression.set(
-                "quoted",
-                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
-            )
-
+        if isinstance(expression, exp.Identifier):
+            expression.set("quoted", self.can_quote(expression, identify or "unsafe"))
         return expression

     def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
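A usage sketch of the new tri-state `identify` argument, assuming Snowflake's uppercase normalization (where an all-uppercase spelling is the case-insensitive one):

```python
from sqlglot import exp
from sqlglot.dialects import Snowflake

dialect = Snowflake()
safe = exp.to_identifier("FOO")    # case-insensitive under uppercase normalization
unsafe = exp.to_identifier("Foo")  # mixed case is case-sensitive

assert dialect.can_quote(safe, "safe")
assert dialect.can_quote(unsafe, "unsafe")
assert not dialect.can_quote(unsafe, "safe")
```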
@@ -1187,11 +1172,11 @@ def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> st
     return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


-def inline_array_sql(self: Generator, expression: exp.Array) -> str:
+def inline_array_sql(self: Generator, expression: exp.Expression) -> str:
     return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


-def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
+def inline_array_unless_query(self: Generator, expression: exp.Expression) -> str:
     elem = seq_get(expression.expressions, 0)
     if isinstance(elem, exp.Expression) and elem.find(exp.Query):
         return self.func("ARRAY", elem)
@@ -1414,12 +1399,14 @@ def date_add_interval_sql(
     return func


-def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
+def timestamptrunc_sql(
+    func: str = "DATE_TRUNC", zone: bool = False
+) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
     def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
         args = [unit_to_str(expression), expression.this]
         if zone:
             args.append(expression.args.get("zone"))
-        return self.func("DATE_TRUNC", *args)
+        return self.func(func, *args)

     return _timestamptrunc_sql

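The new `func` parameter lets dialects whose truncation function isn't DATE_TRUNC reuse this helper; a hypothetical TRANSFORMS entry:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import timestamptrunc_sql

# Hypothetical generator TRANSFORMS entries
TRANSFORMS = {
    # emit TIMESTAMP_TRUNC(...) instead of the default DATE_TRUNC(...)
    exp.TimestampTrunc: timestamptrunc_sql(func="TIMESTAMP_TRUNC"),
}
# timestamptrunc_sql() with no arguments still emits DATE_TRUNC, as before
```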
@@ -1957,6 +1944,10 @@ def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
     return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


+def sha2_digest_sql(self: Generator, expression: exp.SHA2Digest) -> str:
+    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
+
+
 def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
     start = expression.args.get("start")
     end = expression.args.get("end")
@@ -1969,31 +1960,76 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateD
     else:
         target_type = None

-    if start and end and target_type and target_type.is_type("date", "timestamp"):
-        if isinstance(start, exp.Cast) and target_type is start.to:
-            end = exp.cast(end, target_type)
-        else:
-            start = exp.cast(start, target_type)
+    if start and end:
+        if target_type and target_type.is_type("date", "timestamp"):
+            if isinstance(start, exp.Cast) and target_type is start.to:
+                end = exp.cast(end, target_type)
+            else:
+                start = exp.cast(start, target_type)
+
+        if expression.args.get("is_end_exclusive"):
+            step_value = step or exp.Literal.number(1)
+            end = exp.paren(exp.Sub(this=end, expression=step_value), copy=False)
+
+            sequence_call = exp.Anonymous(
+                this="SEQUENCE", expressions=[e for e in (start, end, step) if e]
+            )
+            zero = exp.Literal.number(0)
+            should_return_empty = exp.or_(
+                exp.EQ(this=step_value.copy(), expression=zero.copy()),
+                exp.and_(
+                    exp.GT(this=step_value.copy(), expression=zero.copy()),
+                    exp.GTE(this=start.copy(), expression=end.copy()),
+                ),
+                exp.and_(
+                    exp.LT(this=step_value.copy(), expression=zero.copy()),
+                    exp.LTE(this=start.copy(), expression=end.copy()),
+                ),
+            )
+            empty_array_or_sequence = exp.If(
+                this=should_return_empty,
+                true=exp.Array(expressions=[]),
+                false=sequence_call,
+            )
+            return self.sql(self._simplify_unless_literal(empty_array_or_sequence))

     return self.func("SEQUENCE", start, end, step)


-def build_like(expr_type: t.Type[E]) -> t.Callable[[t.List], E | exp.Escape]:
-    def _builder(args: t.List) -> E | exp.Escape:
-        like_expr = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
-        escape = seq_get(args, 2)
-        return exp.Escape(this=like_expr, expression=escape) if escape else like_expr
+def build_like(
+    expr_type: t.Type[E], not_like: bool = False
+) -> t.Callable[[t.List], exp.Expression]:
+    def _builder(args: t.List) -> exp.Expression:
+        like_expr: exp.Expression = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
+
+        if escape := seq_get(args, 2):
+            like_expr = exp.Escape(this=like_expr, expression=escape)
+
+        if not_like:
+            like_expr = exp.Not(this=like_expr)
+
+        return like_expr

     return _builder


 def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
     def _builder(args: t.List, dialect: Dialect) -> E:
+        # The "position" argument specifies the index of the string character to start matching from.
+        # `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
+        # length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is
+        # only needed for exp.RegexpExtract - exp.RegexpExtractAll always returns an empty array if
+        # position overflows.
         return expr_type(
             this=seq_get(args, 0),
             expression=seq_get(args, 1),
             group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
             parameters=seq_get(args, 3),
+            **(
+                {"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
+                if expr_type is exp.RegexpExtract
+                else {}
+            ),
         )

     return _builder
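A plain-Python model of the emptiness guard built in the sequence_sql hunk above: SEQUENCE raises on an invalid range in Presto/Trino, so the generated IF returns an empty array instead. Note the comparison runs against the already step-adjusted exclusive end:

```python
def should_return_empty(start: int, end: int, step: int) -> bool:
    # Mirrors the generated predicate: zero step, or a step that moves away
    # from the (exclusive, already step-adjusted) end
    return step == 0 or (step > 0 and start >= end) or (step < 0 and start <= end)

# End-exclusive [0, 5) with step 1 is rewritten to SEQUENCE(0, (5 - 1))
start, end, step = 0, 5, 1
assert not should_return_empty(start, end - step, step)
assert should_return_empty(5, 5 - 1, 1)  # [5, 5) is empty
assert should_return_empty(0, 10, 0)     # a zero step would never terminate
```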
@@ -2038,12 +2074,14 @@ def groupconcat_sql(
     self: Generator,
     expression: exp.GroupConcat,
     func_name="LISTAGG",
-    sep: str = ",",
+    sep: t.Optional[str] = ",",
     within_group: bool = True,
     on_overflow: bool = False,
 ) -> str:
     this = expression.this
-    separator = self.sql(expression.args.get("separator") or exp.Literal.string(sep))
+    separator = self.sql(
+        expression.args.get("separator") or (exp.Literal.string(sep) if sep else None)
+    )

     on_overflow_sql = self.sql(expression, "on_overflow")
     on_overflow_sql = f" ON OVERFLOW {on_overflow_sql}" if (on_overflow and on_overflow_sql) else ""
@@ -2059,7 +2097,10 @@ def groupconcat_sql(
     if order and order.this:
         this = order.this.pop()

-    args = self.format_args(this, f"{separator}{on_overflow_sql}")
+    args = self.format_args(
+        this, f"{separator}{on_overflow_sql}" if separator or on_overflow_sql else None
+    )
+
     listagg: exp.Expression = exp.Anonymous(this=func_name, expressions=[args])

     modifiers = self.sql(limit)
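A plain-Python model of the new argument assembly (assuming `format_args` drops missing arguments, which is what makes the change work): when both the separator and the ON OVERFLOW clause are absent, LISTAGG now renders with a single argument instead of a trailing empty string.

```python
def format_args(*args):
    # Stand-in for Generator.format_args, which skips missing arguments
    return ", ".join(a for a in args if a is not None)

separator, on_overflow_sql = "", ""
second = f"{separator}{on_overflow_sql}" if separator or on_overflow_sql else None
assert format_args("x", second) == "x"      # LISTAGG(x)
assert format_args("x", "','") == "x, ','"  # LISTAGG(x, ',')
```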