sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,6 @@ from sqlglot.helper import (
     flatten,
     is_int,
     seq_get,
-    subclasses,
     suggest_closest_match_and_fail,
     to_bool,
 )
@@ -26,6 +25,7 @@ from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time, subsecond_precision
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie
+from sqlglot.typing import EXPRESSION_METADATA

 DATE_ADD_OR_DIFF = t.Union[
     exp.DateAdd,
@@ -44,17 +44,15 @@ DATETIME_DELTA = t.Union[
     exp.DatetimeSub,
     exp.TimeAdd,
     exp.TimeSub,
+    exp.TimestampAdd,
     exp.TimestampSub,
     exp.TsOrDsAdd,
 ]
+DATETIME_ADD = (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.TimestampAdd)

 if t.TYPE_CHECKING:
     from sqlglot._typing import B, E, F

-    from sqlglot.optimizer.annotate_types import TypeAnnotator
-
-    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]
-
 logger = logging.getLogger("sqlglot")

 UNESCAPED_SEQUENCES = {
@@ -69,10 +67,6 @@ UNESCAPED_SEQUENCES = {
 }


-def annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
-    return lambda self, e: self._annotate_with_type(e, data_type)
-
-
 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""

@@ -130,20 +124,6 @@ class NormalizationStrategy(str, AutoName):
     """Always case-insensitive (uppercase), regardless of quotes."""


-class Version(int):
-    def __new__(cls, version_str: t.Optional[str], *args, **kwargs):
-        if version_str:
-            parts = version_str.split(".")
-            parts.extend(["0"] * (3 - len(parts)))
-            v = int("".join([p.zfill(3) for p in parts]))
-        else:
-            # No version defined means we should support the latest engine semantics, so
-            # the comparison to any specific version should yield that latest is greater
-            v = sys.maxsize
-
-        return super(Version, cls).__new__(cls, v)
-
-
 class _Dialect(type):
     _classes: t.Dict[str, t.Type[Dialect]] = {}

@@ -205,7 +185,11 @@ class _Dialect(type):
         klass.FORMAT_TRIE = (
             new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
         )
-        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
+        # Merge class-defined INVERSE_TIME_MAPPING with auto-generated mappings
+        # This allows dialects to define custom inverse mappings for roundtrip correctness
+        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} | (
+            klass.__dict__.get("INVERSE_TIME_MAPPING") or {}
+        )
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
         klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
         klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
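
For illustration (not part of the diff): the merge above uses PEP 584 dict union, where the right-hand operand wins on key conflicts, so an INVERSE_TIME_MAPPING declared on a dialect class body overrides the auto-generated inverse of TIME_MAPPING. A standalone sketch with made-up mappings:

    # Toy TIME_MAPPING and class-level override; not taken from any real dialect.
    time_mapping = {"%Y": "yyyy", "%m": "MM", "%d": "dd"}

    auto_inverse = {v: k for k, v in time_mapping.items()}   # format token -> strftime token
    class_override = {"MM": "%m", "mon": "%b"}                # hand-written inverse entries

    merged = auto_inverse | class_override                    # right-hand side wins on conflicts
    assert merged == {"yyyy": "%Y", "MM": "%m", "dd": "%d", "mon": "%b"}
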
@@ -261,6 +245,9 @@ class _Dialect(type):

         klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

+        if enum not in ("", "bigquery", "snowflake"):
+            klass.INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()

@@ -292,6 +279,54 @@ class _Dialect(type):
             TokenType.SEMI,
         }

+        if enum not in (
+            "",
+            "postgres",
+            "duckdb",
+            "redshift",
+            "snowflake",
+            "presto",
+            "trino",
+            "mysql",
+            "singlestore",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions.pop(TokenType.LOCALTIME, None)
+            if enum != "oracle":
+                no_paren_functions.pop(TokenType.LOCALTIMESTAMP, None)
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+
+        if enum in (
+            "",
+            "postgres",
+            "duckdb",
+            "trino",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions[TokenType.CURRENT_CATALOG] = exp.CurrentCatalog
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+        else:
+            # For dialects that don't support this keyword, treat it as a regular identifier
+            # This fixes the "Unexpected token" error in BQ, Spark, etc.
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.CURRENT_CATALOG,
+            }
+
+        if enum in (
+            "",
+            "duckdb",
+            "spark",
+            "postgres",
+            "tsql",
+        ):
+            no_paren_functions = klass.parser_class.NO_PAREN_FUNCTIONS.copy()
+            no_paren_functions[TokenType.SESSION_USER] = exp.SessionUser
+            klass.parser_class.NO_PAREN_FUNCTIONS = no_paren_functions
+        else:
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.SESSION_USER,
+            }
+
         klass.VALID_INTERVAL_UNITS = {
             *klass.VALID_INTERVAL_UNITS,
             *klass.DATE_PART_MAPPING.keys(),
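
For illustration (not part of the diff): the blocks above always copy NO_PAREN_FUNCTIONS before mutating it, because the dict is a class attribute shared with the base parser class; the sketch below reproduces that copy-then-reassign pattern with toy classes (names mirror the diff, but nothing here touches sqlglot itself):

    from enum import Enum, auto

    class TokenType(Enum):
        SESSION_USER = auto()
        LOCALTIME = auto()

    class BaseParser:
        # Shared class-level mapping, analogous to Parser.NO_PAREN_FUNCTIONS
        NO_PAREN_FUNCTIONS = {TokenType.LOCALTIME: "CurrentTime"}
        ID_VAR_TOKENS: set = set()

    class DialectParser(BaseParser):
        pass

    enum = "postgres"  # hypothetical dialect name for this sketch
    if enum in ("", "duckdb", "spark", "postgres", "tsql"):
        # Copy first so the base class (and sibling dialects) keep their own mapping
        no_paren_functions = DialectParser.NO_PAREN_FUNCTIONS.copy()
        no_paren_functions[TokenType.SESSION_USER] = "SessionUser"
        DialectParser.NO_PAREN_FUNCTIONS = no_paren_functions
    else:
        # Otherwise SESSION_USER stays parseable as a plain identifier
        DialectParser.ID_VAR_TOKENS = DialectParser.ID_VAR_TOKENS | {TokenType.SESSION_USER}

    assert TokenType.SESSION_USER not in BaseParser.NO_PAREN_FUNCTIONS
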
@@ -460,14 +495,139 @@ class Dialect(metaclass=_Dialect):
     to "WHERE id = 1 GROUP BY id HAVING id = 1"
     """

-    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
+    EXPAND_ONLY_GROUP_ALIAS_REF = False
     """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

+    ANNOTATE_ALL_SCOPES = False
+    """Whether to annotate all scopes during optimization. Used by BigQuery for UNNEST support."""
+
+    DISABLES_ALIAS_REF_EXPANSION = False
+    """
+    Whether alias reference expansion is disabled for this dialect.
+
+    Some dialects like Oracle do NOT support referencing aliases in projections or WHERE clauses.
+    The original expression must be repeated instead.
+
+    For example, in Oracle:
+        SELECT y.foo AS bar, bar * 2 AS baz FROM y    -- INVALID
+        SELECT y.foo AS bar, y.foo * 2 AS baz FROM y  -- VALID
+    """
+
+    SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = False
+    """
+    Whether alias references are allowed in JOIN ... ON clauses.
+
+    Most dialects do not support this, but Snowflake allows alias expansion in the JOIN ... ON
+    clause (and almost everywhere else)
+
+    For example, in Snowflake:
+        SELECT a.id AS user_id FROM a JOIN b ON user_id = b.id  -- VALID
+
+    Reference: https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
+    """
+
     SUPPORTS_ORDER_BY_ALL = False
     """
     Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
     """

+    PROJECTION_ALIASES_SHADOW_SOURCE_NAMES = False
+    """
+    Whether projection alias names can shadow table/source names in GROUP BY and HAVING clauses.
+
+    In BigQuery, when a projection alias has the same name as a source table, the alias takes
+    precedence in GROUP BY and HAVING clauses, and the table becomes inaccessible by that name.
+
+    For example, in BigQuery:
+        SELECT id, ARRAY_AGG(col) AS custom_fields
+        FROM custom_fields
+        GROUP BY id
+        HAVING id >= 1
+
+    The "custom_fields" source is shadowed by the projection alias, so we cannot qualify "id"
+    with "custom_fields" in GROUP BY/HAVING.
+    """
+
+    TABLES_REFERENCEABLE_AS_COLUMNS = False
+    """
+    Whether table names can be referenced as columns (treated as structs).
+
+    BigQuery allows tables to be referenced as columns in queries, automatically treating
+    them as struct values containing all the table's columns.
+
+    For example, in BigQuery:
+        SELECT t FROM my_table AS t  -- Returns entire row as a struct
+    """
+
+    SUPPORTS_STRUCT_STAR_EXPANSION = False
+    """
+    Whether the dialect supports expanding struct fields using star notation (e.g., struct_col.*).
+
+    BigQuery allows struct fields to be expanded with the star operator:
+        SELECT t.struct_col.* FROM table t
+    RisingWave also allows struct field expansion with the star operator using parentheses:
+        SELECT (t.struct_col).* FROM table t
+
+    This expands to all fields within the struct.
+    """
+
+    EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = False
+    """
+    Whether pseudocolumns should be excluded from star expansion (SELECT *).
+
+    Pseudocolumns are special dialect-specific columns (e.g., Oracle's ROWNUM, ROWID, LEVEL,
+    or BigQuery's _PARTITIONTIME, _PARTITIONDATE) that are implicitly available but not part
+    of the table schema. When this is True, SELECT * will not include these pseudocolumns;
+    they must be explicitly selected.
+    """
+
+    QUERY_RESULTS_ARE_STRUCTS = False
+    """
+    Whether query results are typed as structs in metadata for type inference.
+
+    In BigQuery, subqueries store their column types as a STRUCT in metadata,
+    enabling special type inference for ARRAY(SELECT ...) expressions:
+        ARRAY(SELECT x, y FROM t) → ARRAY<STRUCT<...>>
+
+    For single column subqueries, BigQuery unwraps the struct:
+        ARRAY(SELECT x FROM t) → ARRAY<type_of_x>
+
+    This is metadata-only for type inference.
+    """
+
+    REQUIRES_PARENTHESIZED_STRUCT_ACCESS = False
+    """
+    Whether struct field access requires parentheses around the expression.
+
+    RisingWave requires parentheses for struct field access in certain contexts:
+        SELECT (col.field).subfield FROM table  -- Parentheses required
+
+    Without parentheses, the parser may not correctly interpret nested struct access.
+
+    Reference: https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
+    """
+
+    SUPPORTS_NULL_TYPE = False
+    """
+    Whether NULL/VOID is supported as a valid data type (not just a value).
+
+    Databricks and Spark v3+ support NULL as an actual type, allowing expressions like:
+        SELECT NULL AS col  -- Has type NULL, not just value NULL
+        CAST(x AS VOID)     -- Valid type cast
+    """
+
+    COALESCE_COMPARISON_NON_STANDARD = False
+    """
+    Whether COALESCE in comparisons has non-standard NULL semantics.
+
+    We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
+    because they are not always equivalent. For example, if `x` is `NULL` and it comes
+    from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`.
+
+    In standard SQL and most dialects, these expressions are equivalent, but Redshift treats
+    table NULLs differently in this context.
+    """
+
     HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
     """
     Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
@@ -509,6 +669,9 @@ class Dialect(metaclass=_Dialect):
     REGEXP_EXTRACT_DEFAULT_GROUP = 0
     """The default value for the capturing group."""

+    REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL = True
+    """Whether REGEXP_EXTRACT returns NULL when the position arg exceeds the string length."""
+
     SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
         exp.Except: True,
         exp.Intersect: True,
@@ -539,6 +702,45 @@ class Dialect(metaclass=_Dialect):
     # STRING type (Snowflake's case) or can be of any type
     TRY_CAST_REQUIRES_STRING: t.Optional[bool] = None

+    # Whether the double negation can be applied
+    # Not safe with MySQL and SQLite due to type coercion (may not return boolean)
+    SAFE_TO_ELIMINATE_DOUBLE_NEGATION = True
+
+    # Whether the INITCAP function supports custom delimiter characters as the second argument
+    # Default delimiter characters for INITCAP function: whitespace and non-alphanumeric characters
+    INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True
+    INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~"
+
+    BYTE_STRING_IS_BYTES_TYPE: bool = False
+    """
+    Whether byte string literals (ex: BigQuery's b'...') are typed as BYTES/BINARY
+    """
+
+    UUID_IS_STRING_TYPE: bool = False
+    """
+    Whether a UUID is considered a string or a UUID type.
+    """
+
+    JSON_EXTRACT_SCALAR_SCALAR_ONLY = False
+    """
+    Whether JSON_EXTRACT_SCALAR returns null if a non-scalar value is selected.
+    """
+
+    DEFAULT_FUNCTIONS_COLUMN_NAMES: t.Dict[t.Type[exp.Func], t.Union[str, t.Tuple[str, ...]]] = {}
+    """
+    Maps function expressions to their default output column name(s).
+
+    For example, in Postgres, generate_series function outputs a column named "generate_series" by default,
+    so we map the ExplodingGenerateSeries expression to "generate_series" string.
+    """
+
+    DEFAULT_NULL_TYPE = exp.DataType.Type.UNKNOWN
+    """
+    The default type of NULL for producing the correct projection type.
+
+    For example, in BigQuery the default type of the NULL value is INT64.
+    """
+
     # --- Autofilled ---

     tokenizer_class = Tokenizer
@@ -600,6 +802,7 @@ class Dialect(metaclass=_Dialect):
         "WEEKDAY_ISO": "DAYOFWEEKISO",
         "DOW_ISO": "DAYOFWEEKISO",
         "DW_ISO": "DAYOFWEEKISO",
+        "DAYOFWEEK_ISO": "DAYOFWEEKISO",
         "DAY OF YEAR": "DAYOFYEAR",
         "DOY": "DAYOFYEAR",
         "DY": "DAYOFYEAR",
@@ -662,232 +865,21 @@ class Dialect(metaclass=_Dialect):
         "DEC": "DECADE",
         "DECS": "DECADE",
         "DECADES": "DECADE",
-        "MIL": "MILLENIUM",
-        "MILS": "MILLENIUM",
-        "MILLENIA": "MILLENIUM",
+        "MIL": "MILLENNIUM",
+        "MILS": "MILLENNIUM",
+        "MILLENIA": "MILLENNIUM",
         "C": "CENTURY",
         "CENT": "CENTURY",
         "CENTS": "CENTURY",
         "CENTURIES": "CENTURY",
     }

-    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
-        exp.DataType.Type.BIGINT: {
-            exp.ApproxDistinct,
-            exp.ArraySize,
-            exp.CountIf,
-            exp.Int64,
-            exp.Length,
-            exp.UnixDate,
-            exp.UnixSeconds,
-            exp.UnixMicros,
-            exp.UnixMillis,
-        },
-        exp.DataType.Type.BINARY: {
-            exp.FromBase32,
-            exp.FromBase64,
-        },
-        exp.DataType.Type.BOOLEAN: {
-            exp.Between,
-            exp.Boolean,
-            exp.Contains,
-            exp.EndsWith,
-            exp.In,
-            exp.LogicalAnd,
-            exp.LogicalOr,
-            exp.RegexpLike,
-            exp.StartsWith,
-        },
-        exp.DataType.Type.DATE: {
-            exp.CurrentDate,
-            exp.Date,
-            exp.DateFromParts,
-            exp.DateStrToDate,
-            exp.DiToDate,
-            exp.LastDay,
-            exp.StrToDate,
-            exp.TimeStrToDate,
-            exp.TsOrDsToDate,
-        },
-        exp.DataType.Type.DATETIME: {
-            exp.CurrentDatetime,
-            exp.Datetime,
-            exp.DatetimeAdd,
-            exp.DatetimeSub,
-        },
-        exp.DataType.Type.DOUBLE: {
-            exp.ApproxQuantile,
-            exp.Avg,
-            exp.Exp,
-            exp.Ln,
-            exp.Log,
-            exp.Pi,
-            exp.Pow,
-            exp.Quantile,
-            exp.Radians,
-            exp.Round,
-            exp.SafeDivide,
-            exp.Sqrt,
-            exp.Stddev,
-            exp.StddevPop,
-            exp.StddevSamp,
-            exp.ToDouble,
-            exp.Variance,
-            exp.VariancePop,
-        },
-        exp.DataType.Type.INT: {
-            exp.Ascii,
-            exp.Ceil,
-            exp.DatetimeDiff,
-            exp.DateDiff,
-            exp.TimestampDiff,
-            exp.TimeDiff,
-            exp.Unicode,
-            exp.DateToDi,
-            exp.Levenshtein,
-            exp.Sign,
-            exp.StrPosition,
-            exp.TsOrDiToDi,
-        },
-        exp.DataType.Type.INTERVAL: {
-            exp.Interval,
-            exp.JustifyDays,
-            exp.JustifyHours,
-            exp.JustifyInterval,
-            exp.MakeInterval,
-        },
-        exp.DataType.Type.JSON: {
-            exp.ParseJSON,
-        },
-        exp.DataType.Type.TIME: {
-            exp.CurrentTime,
-            exp.Time,
-            exp.TimeAdd,
-            exp.TimeSub,
-        },
-        exp.DataType.Type.TIMESTAMPTZ: {
-            exp.CurrentTimestampLTZ,
-        },
-        exp.DataType.Type.TIMESTAMP: {
-            exp.CurrentTimestamp,
-            exp.StrToTime,
-            exp.TimeStrToTime,
-            exp.TimestampAdd,
-            exp.TimestampSub,
-            exp.UnixToTime,
-        },
-        exp.DataType.Type.TINYINT: {
-            exp.Day,
-            exp.Month,
-            exp.Week,
-            exp.Year,
-            exp.Quarter,
-        },
-        exp.DataType.Type.VARCHAR: {
-            exp.ArrayConcat,
-            exp.ArrayToString,
-            exp.Concat,
-            exp.ConcatWs,
-            exp.Chr,
-            exp.DateToDateStr,
-            exp.DPipe,
-            exp.GroupConcat,
-            exp.Initcap,
-            exp.Lower,
-            exp.Substring,
-            exp.String,
-            exp.TimeToStr,
-            exp.TimeToTimeStr,
-            exp.Trim,
-            exp.ToBase32,
-            exp.ToBase64,
-            exp.TsOrDsToDateStr,
-            exp.UnixToStr,
-            exp.UnixToTimeStr,
-            exp.Upper,
-        },
-    }
-
-    ANNOTATORS: AnnotatorsType = {
-        **{
-            expr_type: lambda self, e: self._annotate_unary(e)
-            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
-        },
-        **{
-            expr_type: lambda self, e: self._annotate_binary(e)
-            for expr_type in subclasses(exp.__name__, exp.Binary)
-        },
-        **{
-            expr_type: annotate_with_type_lambda(data_type)
-            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
-            for expr_type in expressions
-        },
-        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
-        exp.AnyValue: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
-        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.ArrayConcatAgg: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArrayFirst: lambda self, e: self._annotate_by_array_element(e),
-        exp.ArrayLast: lambda self, e: self._annotate_by_array_element(e),
-        exp.ArrayReverse: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.ArraySlice: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Bracket: lambda self, e: self._annotate_bracket(e),
-        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
-        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Count: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
-        ),
-        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
-        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
-        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
-        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
-        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
-        exp.Div: lambda self, e: self._annotate_div(e),
-        exp.Dot: lambda self, e: self._annotate_dot(e),
-        exp.Explode: lambda self, e: self._annotate_explode(e),
-        exp.Extract: lambda self, e: self._annotate_extract(e),
-        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.GenerateSeries: lambda self, e: self._annotate_by_args(
-            e, "start", "end", "step", array=True
-        ),
-        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.build("ARRAY<DATE>")
-        ),
-        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
-            e, exp.DataType.build("ARRAY<TIMESTAMP>")
-        ),
-        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
-        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Literal: lambda self, e: self._annotate_literal(e),
-        exp.LastValue: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Map: lambda self, e: self._annotate_map(e),
-        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
-        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
-        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
-        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
-        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
-        exp.Struct: lambda self, e: self._annotate_struct(e),
-        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
-        exp.SortArray: lambda self, e: self._annotate_by_args(e, "this"),
-        exp.Timestamp: lambda self, e: self._annotate_with_type(
-            e,
-            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
-        ),
-        exp.ToMap: lambda self, e: self._annotate_to_map(e),
-        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
-        exp.Unnest: lambda self, e: self._annotate_unnest(e),
-        exp.VarMap: lambda self, e: self._annotate_map(e),
-        exp.Window: lambda self, e: self._annotate_by_args(e, "this"),
-    }
-
     # Specifies what types a given type can be coerced into
     COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

+    # Specifies type inference & validation rules for expressions
+    EXPRESSION_METADATA = EXPRESSION_METADATA.copy()
+
     # Determines the supported Dialect instance settings
     SUPPORTED_SETTINGS = {
         "normalization_strategy",
@@ -967,7 +959,9 @@ class Dialect(metaclass=_Dialect):
         return expression

     def __init__(self, **kwargs) -> None:
-        self.version = Version(kwargs.pop("version", None))
+        parts = str(kwargs.pop("version", sys.maxsize)).split(".")
+        parts.extend(["0"] * (3 - len(parts)))
+        self.version = tuple(int(p) for p in parts[:3])

         normalization_strategy = kwargs.pop("normalization_strategy", None)
         if normalization_strategy is None:
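
For illustration (not part of the diff): with the Version class removed, Dialect.version is now a plain tuple of three ints, padded with zeros, and an unset version falls back to sys.maxsize so it compares greater than any concrete release. A sketch of just that padding logic:

    import sys

    def parse_version(version=sys.maxsize):
        # Mirrors Dialect.__init__ above: no version means "latest engine semantics"
        parts = str(version).split(".")
        parts.extend(["0"] * (3 - len(parts)))
        return tuple(int(p) for p in parts[:3])

    assert parse_version("8.0") == (8, 0, 0)      # two-part versions are zero-padded
    assert parse_version("3.10.2") == (3, 10, 2)
    assert parse_version() > (999, 0, 0)          # unset version behaves as "latest"
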
@@ -1044,42 +1038,50 @@ class Dialect(metaclass=_Dialect):
         )
         return any(unsafe(char) for char in text)

-    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
-        """Checks if text can be identified given an identify option.
+    def can_quote(self, identifier: exp.Identifier, identify: str | bool = "safe") -> bool:
+        """Checks if an identifier can be quoted

         Args:
-            text: The text to check.
+            identifier: The identifier to check.
             identify:
-                `"always"` or `True`: Always returns `True`.
+                `True`: Always returns `True` except for certain cases.
                 `"safe"`: Only returns `True` if the identifier is case-insensitive.
+                `"unsafe"`: Only returns `True` if the identifier is case-sensitive.

         Returns:
             Whether the given text can be identified.
         """
-        if identify is True or identify == "always":
+        if identifier.quoted:
+            return True
+        if not identify:
+            return False
+        if isinstance(identifier.parent, exp.Func):
+            return False
+        if identify is True:
             return True

+        is_safe = not self.case_sensitive(identifier.this) and bool(
+            exp.SAFE_IDENTIFIER_RE.match(identifier.this)
+        )
+
         if identify == "safe":
-            return not self.case_sensitive(text)
+            return is_safe
+        if identify == "unsafe":
+            return not is_safe

-        return False
+        raise ValueError(f"Unexpected argument for identify: '{identify}'")

     def quote_identifier(self, expression: E, identify: bool = True) -> E:
         """
-        Adds quotes to a given identifier.
+        Adds quotes to a given expression if it is an identifier.

         Args:
             expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
             identify: If set to `False`, the quotes will only be added if the identifier is deemed
                 "unsafe", with respect to its characters and this dialect's normalization strategy.
         """
-        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
-            name = expression.this
-            expression.set(
-                "quoted",
-                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
-            )
-
+        if isinstance(expression, exp.Identifier):
+            expression.set("quoted", self.can_quote(expression, identify or "unsafe"))
         return expression

     def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
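
For illustration (not part of the diff): quote_identifier now routes through can_quote, so identify=False quotes only identifiers that are case-sensitive for the dialect or fail SAFE_IDENTIFIER_RE, while identify=True forces quoting. A hedged usage sketch (dialect and identifier names are arbitrary; the expected output is inferred from the code above, not verified):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    dialect = Dialect.get_or_raise("duckdb")

    safe = dialect.quote_identifier(exp.to_identifier("my_col"), identify=False)
    unsafe = dialect.quote_identifier(exp.to_identifier("my col"), identify=False)
    forced = dialect.quote_identifier(exp.to_identifier("my_col"), identify=True)

    # expected: my_col  "my col"  "my_col"
    print(safe.sql(), unsafe.sql(), forced.sql())
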
@@ -1170,11 +1172,11 @@ def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
     return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


-def inline_array_sql(self: Generator, expression: exp.Array) -> str:
+def inline_array_sql(self: Generator, expression: exp.Expression) -> str:
     return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


-def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
+def inline_array_unless_query(self: Generator, expression: exp.Expression) -> str:
     elem = seq_get(expression.expressions, 0)
     if isinstance(elem, exp.Expression) and elem.find(exp.Query):
         return self.func("ARRAY", elem)
@@ -1397,12 +1399,14 @@ def date_add_interval_sql(
     return func


-def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
+def timestamptrunc_sql(
+    func: str = "DATE_TRUNC", zone: bool = False
+) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
     def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
         args = [unit_to_str(expression), expression.this]
         if zone:
             args.append(expression.args.get("zone"))
-        return self.func("DATE_TRUNC", *args)
+        return self.func(func, *args)

     return _timestamptrunc_sql

@@ -1682,11 +1686,7 @@ def date_delta_to_binary_interval_op(
     def date_delta_to_binary_interval_op_sql(self: Generator, expression: DATETIME_DELTA) -> str:
         this = expression.this
         unit = unit_to_var(expression)
-        op = (
-            "+"
-            if isinstance(expression, (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd))
-            else "-"
-        )
+        op = "+" if isinstance(expression, DATETIME_ADD) else "-"

         to_type: t.Optional[exp.DATA_TYPE] = None
         if cast:
@@ -1944,6 +1944,10 @@ def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
     return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


+def sha2_digest_sql(self: Generator, expression: exp.SHA2Digest) -> str:
+    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
+
+
 def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
     start = expression.args.get("start")
     end = expression.args.get("end")
@@ -1956,22 +1960,76 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
     else:
         target_type = None

-    if start and end and target_type and target_type.is_type("date", "timestamp"):
-        if isinstance(start, exp.Cast) and target_type is start.to:
-            end = exp.cast(end, target_type)
-        else:
-            start = exp.cast(start, target_type)
+    if start and end:
+        if target_type and target_type.is_type("date", "timestamp"):
+            if isinstance(start, exp.Cast) and target_type is start.to:
+                end = exp.cast(end, target_type)
+            else:
+                start = exp.cast(start, target_type)
+
+        if expression.args.get("is_end_exclusive"):
+            step_value = step or exp.Literal.number(1)
+            end = exp.paren(exp.Sub(this=end, expression=step_value), copy=False)
+
+            sequence_call = exp.Anonymous(
+                this="SEQUENCE", expressions=[e for e in (start, end, step) if e]
+            )
+            zero = exp.Literal.number(0)
+            should_return_empty = exp.or_(
+                exp.EQ(this=step_value.copy(), expression=zero.copy()),
+                exp.and_(
+                    exp.GT(this=step_value.copy(), expression=zero.copy()),
+                    exp.GTE(this=start.copy(), expression=end.copy()),
+                ),
+                exp.and_(
+                    exp.LT(this=step_value.copy(), expression=zero.copy()),
+                    exp.LTE(this=start.copy(), expression=end.copy()),
+                ),
+            )
+            empty_array_or_sequence = exp.If(
+                this=should_return_empty,
+                true=exp.Array(expressions=[]),
+                false=sequence_call,
+            )
+            return self.sql(self._simplify_unless_literal(empty_array_or_sequence))

     return self.func("SEQUENCE", start, end, step)


+def build_like(
+    expr_type: t.Type[E], not_like: bool = False
+) -> t.Callable[[t.List], exp.Expression]:
+    def _builder(args: t.List) -> exp.Expression:
+        like_expr: exp.Expression = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
+
+        if escape := seq_get(args, 2):
+            like_expr = exp.Escape(this=like_expr, expression=escape)
+
+        if not_like:
+            like_expr = exp.Not(this=like_expr)
+
+        return like_expr
+
+    return _builder
+
+
 def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
     def _builder(args: t.List, dialect: Dialect) -> E:
+        # The "position" argument specifies the index of the string character to start matching from.
+        # `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
+        # length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is
+        # only needed for exp.RegexpExtract - exp.RegexpExtractAll always returns an empty array if
+        # position overflows.
         return expr_type(
             this=seq_get(args, 0),
             expression=seq_get(args, 1),
             group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
             parameters=seq_get(args, 3),
+            **(
+                {"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
+                if expr_type is exp.RegexpExtract
+                else {}
+            ),
         )

     return _builder
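
For illustration (not part of the diff): build_like is a positional-args builder (haystack, pattern, optional escape) in the style of the parser's FUNCTIONS entries; the sketch below only shows the ASTs it constructs, with arbitrary literals and the default generator:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_like

    like = build_like(exp.Like)([exp.column("name"), exp.Literal.string("a%")])
    print(like.sql())  # expected: name LIKE 'a%'

    not_ilike = build_like(exp.ILike, not_like=True)(
        [exp.column("name"), exp.Literal.string("a!%"), exp.Literal.string("!")]
    )
    print(not_ilike.sql())  # expected: NOT name ILIKE 'a!%' ESCAPE '!'
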
@@ -2016,12 +2074,14 @@ def groupconcat_sql(
     self: Generator,
     expression: exp.GroupConcat,
     func_name="LISTAGG",
-    sep: str = ",",
+    sep: t.Optional[str] = ",",
     within_group: bool = True,
     on_overflow: bool = False,
 ) -> str:
     this = expression.this
-    separator = self.sql(expression.args.get("separator") or exp.Literal.string(sep))
+    separator = self.sql(
+        expression.args.get("separator") or (exp.Literal.string(sep) if sep else None)
+    )

     on_overflow_sql = self.sql(expression, "on_overflow")
     on_overflow_sql = f" ON OVERFLOW {on_overflow_sql}" if (on_overflow and on_overflow_sql) else ""
@@ -2037,7 +2097,10 @@ groupconcat_sql(
     if order and order.this:
         this = order.this.pop()

-    args = self.format_args(this, f"{separator}{on_overflow_sql}")
+    args = self.format_args(
+        this, f"{separator}{on_overflow_sql}" if separator or on_overflow_sql else None
+    )
+
     listagg: exp.Expression = exp.Anonymous(this=func_name, expressions=[args])

     modifiers = self.sql(limit)
@@ -2075,3 +2138,19 @@ def build_replace_with_optional_replacement(args: t.List) -> exp.Replace:
         expression=seq_get(args, 1),
         replacement=seq_get(args, 2) or exp.Literal.string(""),
     )
+
+
+def regexp_replace_global_modifier(expression: exp.RegexpReplace) -> exp.Expression | None:
+    modifiers = expression.args.get("modifiers")
+    single_replace = expression.args.get("single_replace")
+    occurrence = expression.args.get("occurrence")
+
+    if not single_replace and (not occurrence or (occurrence.is_int and occurrence.to_py() == 0)):
+        if not modifiers or modifiers.is_string:
+            # Append 'g' to the modifiers if they are not provided since
+            # the semantics of REGEXP_REPLACE from the input dialect
+            # is to replace all occurrences of the pattern.
+            value = "" if not modifiers else modifiers.name
+            modifiers = exp.Literal.string(value + "g")
+
+    return modifiers
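
For illustration (not part of the diff): regexp_replace_global_modifier computes the modifiers argument a transpiled REGEXP_REPLACE should carry when the source dialect's semantics are replace-all. A sketch that constructs the nodes by hand (argument names follow the diff; real dialects populate them during parsing):

    from sqlglot import exp
    from sqlglot.dialects.dialect import regexp_replace_global_modifier

    node = exp.RegexpReplace(
        this=exp.column("s"),
        expression=exp.Literal.string("a+"),
        replacement=exp.Literal.string("b"),
    )
    # No occurrence/single_replace args: replace-all semantics, so 'g' is produced.
    print(regexp_replace_global_modifier(node))   # expected: string literal 'g'

    # Existing string modifiers get 'g' appended rather than replaced.
    node.set("modifiers", exp.Literal.string("i"))
    print(regexp_replace_global_modifier(node))   # expected: string literal 'ig'
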