sqlglot 27.10.0__py3-none-any.whl → 27.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlglot/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '27.10.0'
32
- __version_tuple__ = version_tuple = (27, 10, 0)
31
+ __version__ = version = '27.12.0'
32
+ __version_tuple__ = version_tuple = (27, 12, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -4,6 +4,9 @@ import logging
4
4
  import re
5
5
  import typing as t
6
6
 
7
+
8
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
9
+
7
10
  from sqlglot import exp, generator, jsonpath, parser, tokens, transforms
8
11
  from sqlglot._typing import E
9
12
  from sqlglot.dialects.dialect import (
@@ -172,6 +175,18 @@ def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
172
175
  return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.LowerHex(this=arg)
173
176
 
174
177
 
178
+ def _build_json_strip_nulls(args: t.List) -> exp.JSONStripNulls:
179
+ expression = exp.JSONStripNulls(this=seq_get(args, 0))
180
+
181
+ for arg in args[1:]:
182
+ if isinstance(arg, exp.Kwarg):
183
+ expression.set(arg.this.name.lower(), arg)
184
+ else:
185
+ expression.set("expression", arg)
186
+
187
+ return expression
188
+
189
+
175
190
  def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str:
176
191
  return self.sql(
177
192
  exp.Exists(
@@ -295,6 +310,23 @@ def _annotate_math_functions(self: TypeAnnotator, expression: E) -> E:
295
310
  return expression
296
311
 
297
312
 
313
+ def _annotate_by_args_with_coerce(self: TypeAnnotator, expression: E) -> E:
314
+ """
315
+ +------------+------------+------------+-------------+---------+
316
+ | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
317
+ +------------+------------+------------+-------------+---------+
318
+ | INT64 | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
319
+ | NUMERIC | NUMERIC | NUMERIC | BIGNUMERIC | FLOAT64 |
320
+ | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | FLOAT64 |
321
+ | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 |
322
+ +------------+------------+------------+-------------+---------+
323
+ """
324
+ self._annotate_args(expression)
325
+
326
+ self._set_type(expression, self._maybe_coerce(expression.this.type, expression.expression.type))
327
+ return expression
328
+
329
+
298
330
  def _annotate_by_args_approx_top(self: TypeAnnotator, expression: exp.ApproxTopK) -> exp.ApproxTopK:
299
331
  self._annotate_args(expression)
300
332
 
@@ -453,6 +485,14 @@ class BigQuery(Dialect):
453
485
  # All set operations require either a DISTINCT or ALL specifier
454
486
  SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None)
455
487
 
488
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/navigation_functions#percentile_cont
489
+ COERCES_TO = {
490
+ **TypeAnnotator.COERCES_TO,
491
+ exp.DataType.Type.BIGDECIMAL: {exp.DataType.Type.DOUBLE},
492
+ }
493
+ COERCES_TO[exp.DataType.Type.DECIMAL] |= {exp.DataType.Type.BIGDECIMAL}
494
+ COERCES_TO[exp.DataType.Type.BIGINT] |= {exp.DataType.Type.BIGDECIMAL}
495
+
456
496
  # BigQuery maps Type.TIMESTAMP to DATETIME, so we need to amend the inferred types
457
497
  TYPE_TO_EXPRESSIONS = {
458
498
  **Dialect.TYPE_TO_EXPRESSIONS,
@@ -474,23 +514,47 @@ class BigQuery(Dialect):
474
514
  **{
475
515
  expr_type: lambda self, e: self._annotate_by_args(e, "this")
476
516
  for expr_type in (
517
+ exp.Abs,
518
+ exp.ArgMax,
519
+ exp.ArgMin,
520
+ exp.DateTrunc,
521
+ exp.DatetimeTrunc,
522
+ exp.FirstValue,
523
+ exp.GroupConcat,
524
+ exp.IgnoreNulls,
525
+ exp.JSONExtract,
526
+ exp.Lead,
477
527
  exp.Left,
478
- exp.Right,
479
528
  exp.Lower,
480
- exp.Upper,
529
+ exp.NthValue,
481
530
  exp.Pad,
482
- exp.Trim,
531
+ exp.PercentileDisc,
483
532
  exp.RegexpExtract,
484
533
  exp.RegexpReplace,
485
534
  exp.Repeat,
535
+ exp.Replace,
536
+ exp.RespectNulls,
537
+ exp.Reverse,
538
+ exp.Right,
539
+ exp.SafeNegate,
540
+ exp.Sign,
486
541
  exp.Substring,
542
+ exp.TimestampTrunc,
543
+ exp.Translate,
544
+ exp.Trim,
545
+ exp.Upper,
487
546
  )
488
547
  },
548
+ exp.Acos: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
549
+ exp.Acosh: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
550
+ exp.Asin: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
551
+ exp.Asinh: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
552
+ exp.Atan: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
553
+ exp.Atanh: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
554
+ exp.Atan2: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
489
555
  exp.ApproxTopSum: lambda self, e: _annotate_by_args_approx_top(self, e),
490
556
  exp.ApproxTopK: lambda self, e: _annotate_by_args_approx_top(self, e),
491
557
  exp.ApproxQuantiles: lambda self, e: self._annotate_by_args(e, "this", array=True),
492
- exp.ArgMax: lambda self, e: self._annotate_by_args(e, "this"),
493
- exp.ArgMin: lambda self, e: self._annotate_by_args(e, "this"),
494
558
  exp.Array: _annotate_array,
495
559
  exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
496
560
  exp.Ascii: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
@@ -500,6 +564,7 @@ class BigQuery(Dialect):
500
564
  exp.BitwiseCountAgg: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
501
565
  exp.ByteLength: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
502
566
  exp.ByteString: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BINARY),
567
+ exp.Cbrt: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
503
568
  exp.CodePointsToBytes: lambda self, e: self._annotate_with_type(
504
569
  e, exp.DataType.Type.BINARY
505
570
  ),
@@ -509,59 +574,99 @@ class BigQuery(Dialect):
509
574
  exp.Concat: _annotate_concat,
510
575
  exp.Contains: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BOOLEAN),
511
576
  exp.Corr: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
577
+ exp.Cot: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
578
+ exp.CosineDistance: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
579
+ exp.Coth: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
512
580
  exp.CovarPop: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
513
581
  exp.CovarSamp: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
582
+ exp.Csc: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
583
+ exp.Csch: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
584
+ exp.CumeDist: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
514
585
  exp.DateFromUnixDate: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DATE),
515
- exp.DateTrunc: lambda self, e: self._annotate_by_args(e, "this"),
586
+ exp.DenseRank: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
587
+ exp.EuclideanDistance: lambda self, e: self._annotate_with_type(
588
+ e, exp.DataType.Type.DOUBLE
589
+ ),
516
590
  exp.FarmFingerprint: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
517
591
  exp.Unhex: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BINARY),
518
592
  exp.Float64: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
593
+ exp.Format: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
519
594
  exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
520
595
  e, exp.DataType.build("ARRAY<TIMESTAMP>", dialect="bigquery")
521
596
  ),
522
597
  exp.Grouping: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
598
+ exp.IsInf: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BOOLEAN),
599
+ exp.IsNan: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BOOLEAN),
523
600
  exp.JSONArray: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
601
+ exp.JSONArrayAppend: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
602
+ exp.JSONArrayInsert: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
524
603
  exp.JSONBool: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BOOLEAN),
525
604
  exp.JSONExtractScalar: lambda self, e: self._annotate_with_type(
526
605
  e, exp.DataType.Type.VARCHAR
527
606
  ),
528
- exp.JSONValueArray: lambda self, e: self._annotate_with_type(
529
- e, exp.DataType.build("ARRAY<VARCHAR>")
607
+ exp.JSONExtractArray: lambda self, e: self._annotate_by_args(e, "this", array=True),
608
+ exp.JSONFormat: lambda self, e: self._annotate_with_type(
609
+ e, exp.DataType.Type.JSON if e.args.get("to_json") else exp.DataType.Type.VARCHAR
610
+ ),
611
+ exp.JSONKeysAtDepth: lambda self, e: self._annotate_with_type(
612
+ e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
530
613
  ),
614
+ exp.JSONObject: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
615
+ exp.JSONRemove: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
616
+ exp.JSONSet: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
617
+ exp.JSONStripNulls: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.JSON),
531
618
  exp.JSONType: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
619
+ exp.JSONValueArray: lambda self, e: self._annotate_with_type(
620
+ e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
621
+ ),
532
622
  exp.Lag: lambda self, e: self._annotate_by_args(e, "this", "default"),
533
623
  exp.LowerHex: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
624
+ exp.LaxBool: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BOOLEAN),
625
+ exp.LaxFloat64: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
626
+ exp.LaxInt64: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
627
+ exp.LaxString: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
534
628
  exp.MD5Digest: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BINARY),
535
629
  exp.Normalize: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
630
+ exp.Ntile: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
536
631
  exp.ParseTime: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.TIME),
537
632
  exp.ParseDatetime: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DATETIME),
538
633
  exp.ParseBignumeric: lambda self, e: self._annotate_with_type(
539
634
  e, exp.DataType.Type.BIGDECIMAL
540
635
  ),
541
636
  exp.ParseNumeric: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DECIMAL),
637
+ exp.PercentileCont: lambda self, e: _annotate_by_args_with_coerce(self, e),
638
+ exp.PercentRank: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
639
+ exp.Rank: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
640
+ exp.RangeBucket: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
542
641
  exp.RegexpExtractAll: lambda self, e: self._annotate_by_args(e, "this", array=True),
543
- exp.Replace: lambda self, e: self._annotate_by_args(e, "this"),
544
- exp.Reverse: lambda self, e: self._annotate_by_args(e, "this"),
642
+ exp.RegexpInstr: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
643
+ exp.RowNumber: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
644
+ exp.Rand: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
545
645
  exp.SafeConvertBytesToString: lambda self, e: self._annotate_with_type(
546
646
  e, exp.DataType.Type.VARCHAR
547
647
  ),
648
+ exp.SafeAdd: lambda self, e: _annotate_by_args_with_coerce(self, e),
649
+ exp.SafeMultiply: lambda self, e: _annotate_by_args_with_coerce(self, e),
650
+ exp.SafeSubtract: lambda self, e: _annotate_by_args_with_coerce(self, e),
651
+ exp.Sec: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
652
+ exp.Sech: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
548
653
  exp.Soundex: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
549
654
  exp.SHA: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BINARY),
550
655
  exp.SHA2: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BINARY),
551
- exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
656
+ exp.Sin: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
657
+ exp.Sinh: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.DOUBLE),
552
658
  exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True),
553
659
  exp.TimestampFromParts: lambda self, e: self._annotate_with_type(
554
660
  e, exp.DataType.Type.DATETIME
555
661
  ),
556
- exp.TimestampTrunc: lambda self, e: self._annotate_by_args(e, "this"),
557
662
  exp.TimeFromParts: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.TIME),
558
663
  exp.TimeTrunc: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.TIME),
559
664
  exp.ToCodePoints: lambda self, e: self._annotate_with_type(
560
665
  e, exp.DataType.build("ARRAY<BIGINT>", dialect="bigquery")
561
666
  ),
562
667
  exp.TsOrDsToTime: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.TIME),
563
- exp.Translate: lambda self, e: self._annotate_by_args(e, "this"),
564
668
  exp.Unicode: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.BIGINT),
669
+ exp.Uuid: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.VARCHAR),
565
670
  }
566
671
 
567
672
  def normalize_identifier(self, expression: E) -> E:
@@ -682,8 +787,11 @@ class BigQuery(Dialect):
682
787
  "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
683
788
  "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path(exp.JSONExtractScalar),
684
789
  "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray),
790
+ "JSON_EXTRACT_STRING_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray),
791
+ "JSON_KEYS": exp.JSONKeysAtDepth.from_arg_list,
685
792
  "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract),
686
793
  "JSON_QUERY_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray),
794
+ "JSON_STRIP_NULLS": _build_json_strip_nulls,
687
795
  "JSON_VALUE": _build_extract_json_with_default_path(exp.JSONExtractScalar),
688
796
  "JSON_VALUE_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray),
689
797
  "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
@@ -730,6 +838,9 @@ class BigQuery(Dialect):
730
838
  this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
731
839
  ),
732
840
  "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
841
+ "TO_JSON": lambda args: exp.JSONFormat(
842
+ this=seq_get(args, 0), options=seq_get(args, 1), to_json=True
843
+ ),
733
844
  "TO_JSON_STRING": exp.JSONFormat.from_arg_list,
734
845
  "FORMAT_DATETIME": _build_format_time(exp.TsOrDsToDatetime),
735
846
  "FORMAT_TIMESTAMP": _build_format_time(exp.TsOrDsToTimestamp),
@@ -798,9 +909,13 @@ class BigQuery(Dialect):
798
909
  "SAFE_ORDINAL": (1, True),
799
910
  }
800
911
 
801
- def _parse_for_in(self) -> exp.ForIn:
912
+ def _parse_for_in(self) -> t.Union[exp.ForIn, exp.Command]:
913
+ index = self._index
802
914
  this = self._parse_range()
803
915
  self._match_text_seq("DO")
916
+ if self._match(TokenType.COMMAND):
917
+ self._retreat(index)
918
+ return self._parse_as_command(self._prev)
804
919
  return self.expression(exp.ForIn, this=this, expression=self._parse_statement())
805
920
 
806
921
  def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
@@ -1196,7 +1311,13 @@ class BigQuery(Dialect):
1196
1311
  exp.JSONExtract: _json_extract_sql,
1197
1312
  exp.JSONExtractArray: _json_extract_sql,
1198
1313
  exp.JSONExtractScalar: _json_extract_sql,
1199
- exp.JSONFormat: rename_func("TO_JSON_STRING"),
1314
+ exp.JSONFormat: lambda self, e: self.func(
1315
+ "TO_JSON" if e.args.get("to_json") else "TO_JSON_STRING",
1316
+ e.this,
1317
+ e.args.get("options"),
1318
+ ),
1319
+ exp.JSONKeysAtDepth: rename_func("JSON_KEYS"),
1320
+ exp.JSONValueArray: rename_func("JSON_VALUE_ARRAY"),
1200
1321
  exp.Levenshtein: _levenshtein_sql,
1201
1322
  exp.Max: max_or_greatest,
1202
1323
  exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
@@ -312,6 +312,7 @@ class ClickHouse(Dialect):
312
312
  "ARRAYREVERSE": exp.ArrayReverse.from_arg_list,
313
313
  "ARRAYSLICE": exp.ArraySlice.from_arg_list,
314
314
  "COUNTIF": _build_count_if,
315
+ "COSINEDISTANCE": exp.CosineDistance.from_arg_list,
315
316
  "DATE_ADD": build_date_delta(exp.DateAdd, default_unit=None),
316
317
  "DATEADD": build_date_delta(exp.DateAdd, default_unit=None),
317
318
  "DATE_DIFF": build_date_delta(exp.DateDiff, default_unit=None, supports_timezone=True),
@@ -324,6 +325,7 @@ class ClickHouse(Dialect):
324
325
  exp.JSONExtractScalar, zero_based_indexing=False
325
326
  ),
326
327
  "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
328
+ "L2Distance": exp.EuclideanDistance.from_arg_list,
327
329
  "MAP": parser.build_var_map,
328
330
  "MATCH": exp.RegexpLike.from_arg_list,
329
331
  "PARSEDATETIME": _build_datetime_format(exp.ParseDatetime),
@@ -1094,6 +1096,7 @@ class ClickHouse(Dialect):
1094
1096
  exp.Array: inline_array_sql,
1095
1097
  exp.CastToStrType: rename_func("CAST"),
1096
1098
  exp.CountIf: rename_func("countIf"),
1099
+ exp.CosineDistance: rename_func("cosineDistance"),
1097
1100
  exp.CompressColumnConstraint: lambda self,
1098
1101
  e: f"CODEC({self.expressions(e, key='this', flat=True)})",
1099
1102
  exp.ComputedColumnConstraint: lambda self,
@@ -1123,6 +1126,7 @@ class ClickHouse(Dialect):
1123
1126
  exp.Rand: rename_func("randCanonical"),
1124
1127
  exp.StartsWith: rename_func("startsWith"),
1125
1128
  exp.EndsWith: rename_func("endsWith"),
1129
+ exp.EuclideanDistance: rename_func("L2Distance"),
1126
1130
  exp.StrPosition: lambda self, e: strposition_sql(
1127
1131
  self,
1128
1132
  e,
@@ -106,6 +106,7 @@ class Databricks(Spark):
106
106
  ),
107
107
  }
108
108
 
109
+ TRANSFORMS.pop(exp.RegexpLike)
109
110
  TRANSFORMS.pop(exp.TryCast)
110
111
 
111
112
  TYPE_MAPPING = {
sqlglot/dialects/doris.py CHANGED
@@ -50,6 +50,7 @@ class Doris(MySQL):
50
50
  **MySQL.Parser.FUNCTIONS,
51
51
  "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
52
52
  "DATE_TRUNC": _build_date_trunc,
53
+ "L2_DISTANCE": exp.EuclideanDistance.from_arg_list,
53
54
  "MONTHS_ADD": exp.AddMonths.from_arg_list,
54
55
  "REGEXP": exp.RegexpLike.from_arg_list,
55
56
  "TO_DATE": exp.TsOrDsToDate.from_arg_list,
@@ -210,6 +211,7 @@ class Doris(MySQL):
210
211
  exp.CurrentDate: lambda self, _: self.func("CURRENT_DATE"),
211
212
  exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
212
213
  exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
214
+ exp.EuclideanDistance: rename_func("L2_DISTANCE"),
213
215
  exp.GroupConcat: lambda self, e: self.func(
214
216
  "GROUP_CONCAT", e.this, e.args.get("separator") or exp.Literal.string(",")
215
217
  ),
@@ -74,6 +74,27 @@ def build_date_delta_with_cast_interval(
74
74
  return _builder
75
75
 
76
76
 
77
+ def datetype_handler(args: t.List[exp.Expression], dialect: DialectType) -> exp.Expression:
78
+ year, month, day = args
79
+
80
+ if all(isinstance(arg, exp.Literal) and arg.is_int for arg in (year, month, day)):
81
+ date_str = f"{int(year.this):04d}-{int(month.this):02d}-{int(day.this):02d}"
82
+ return exp.Date(this=exp.Literal.string(date_str))
83
+
84
+ return exp.Cast(
85
+ this=exp.Concat(
86
+ expressions=[
87
+ year,
88
+ exp.Literal.string("-"),
89
+ month,
90
+ exp.Literal.string("-"),
91
+ day,
92
+ ]
93
+ ),
94
+ to=exp.DataType.build("DATE"),
95
+ )
96
+
97
+
77
98
  class Dremio(Dialect):
78
99
  SUPPORTS_USER_DEFINED_TYPES = False
79
100
  CONCAT_COALESCE = True
@@ -145,12 +166,16 @@ class Dremio(Dialect):
145
166
 
146
167
  FUNCTIONS = {
147
168
  **parser.Parser.FUNCTIONS,
148
- "TO_CHAR": to_char_is_numeric_handler,
149
- "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "dremio"),
150
- "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "dremio"),
169
+ "ARRAY_GENERATE_RANGE": exp.GenerateSeries.from_arg_list,
151
170
  "DATE_ADD": build_date_delta_with_cast_interval(exp.DateAdd),
171
+ "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "dremio"),
152
172
  "DATE_SUB": build_date_delta_with_cast_interval(exp.DateSub),
153
- "ARRAY_GENERATE_RANGE": exp.GenerateSeries.from_arg_list,
173
+ "REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
174
+ "REPEATSTR": exp.Repeat.from_arg_list,
175
+ "TO_CHAR": to_char_is_numeric_handler,
176
+ "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "dremio"),
177
+ "DATE_PART": exp.Extract.from_arg_list,
178
+ "DATETYPE": datetype_handler,
154
179
  }
155
180
 
156
181
  def _parse_current_date_utc(self) -> exp.Cast:
@@ -304,7 +304,6 @@ class DuckDB(Dialect):
304
304
  "CHAR": TokenType.TEXT,
305
305
  "DATETIME": TokenType.TIMESTAMPNTZ,
306
306
  "DETACH": TokenType.DETACH,
307
- "EXCLUDE": TokenType.EXCEPT,
308
307
  "LOGICAL": TokenType.BOOLEAN,
309
308
  "ONLY": TokenType.ONLY,
310
309
  "PIVOT_WIDER": TokenType.PIVOT,
@@ -386,6 +385,8 @@ class DuckDB(Dialect):
386
385
  "JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract),
387
386
  "JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar),
388
387
  "LIST_CONTAINS": exp.ArrayContains.from_arg_list,
388
+ "LIST_COSINE_DISTANCE": exp.CosineDistance.from_arg_list,
389
+ "LIST_DISTANCE": exp.EuclideanDistance.from_arg_list,
389
390
  "LIST_FILTER": exp.ArrayFilter.from_arg_list,
390
391
  "LIST_HAS": exp.ArrayContains.from_arg_list,
391
392
  "LIST_HAS_ANY": exp.ArrayOverlaps.from_arg_list,
@@ -650,6 +651,7 @@ class DuckDB(Dialect):
650
651
  ),
651
652
  exp.BitwiseXor: rename_func("XOR"),
652
653
  exp.CommentColumnConstraint: no_comment_column_constraint_sql,
654
+ exp.CosineDistance: rename_func("LIST_COSINE_DISTANCE"),
653
655
  exp.CurrentDate: lambda *_: "CURRENT_DATE",
654
656
  exp.CurrentTime: lambda *_: "CURRENT_TIME",
655
657
  exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@@ -673,6 +675,7 @@ class DuckDB(Dialect):
673
675
  exp.DiToDate: lambda self,
674
676
  e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
675
677
  exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
678
+ exp.EuclideanDistance: rename_func("LIST_DISTANCE"),
676
679
  exp.GenerateDateArray: _generate_datetime_array_sql,
677
680
  exp.GenerateTimestampArray: _generate_datetime_array_sql,
678
681
  exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, within_group=False),
sqlglot/dialects/hive.py CHANGED
@@ -194,6 +194,16 @@ def _build_to_date(args: t.List) -> exp.TsOrDsToDate:
194
194
  return expr
195
195
 
196
196
 
197
+ def _build_date_add(args: t.List) -> exp.TsOrDsAdd:
198
+ expression = seq_get(args, 1)
199
+ if expression:
200
+ expression = expression * -1
201
+
202
+ return exp.TsOrDsAdd(
203
+ this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
204
+ )
205
+
206
+
197
207
  class Hive(Dialect):
198
208
  ALIAS_POST_TABLESAMPLE = True
199
209
  IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -314,11 +324,7 @@ class Hive(Dialect):
314
324
  seq_get(args, 1),
315
325
  ]
316
326
  ),
317
- "DATE_SUB": lambda args: exp.TsOrDsAdd(
318
- this=seq_get(args, 0),
319
- expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
320
- unit=exp.Literal.string("DAY"),
321
- ),
327
+ "DATE_SUB": _build_date_add,
322
328
  "DATEDIFF": lambda args: exp.DateDiff(
323
329
  this=exp.TsOrDsToDate(this=seq_get(args, 0)),
324
330
  expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
@@ -107,6 +107,7 @@ class Oracle(Dialect):
107
107
  FUNCTIONS = {
108
108
  **parser.Parser.FUNCTIONS,
109
109
  "CONVERT": exp.ConvertToCharset.from_arg_list,
110
+ "L2_DISTANCE": exp.EuclideanDistance.from_arg_list,
110
111
  "NVL": lambda args: build_coalesce(args, is_nvl=True),
111
112
  "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
112
113
  "TO_CHAR": build_timetostr_or_tochar,
@@ -305,6 +306,7 @@ class Oracle(Dialect):
305
306
  "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
306
307
  ),
307
308
  exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.unit),
309
+ exp.EuclideanDistance: rename_func("L2_DISTANCE"),
308
310
  exp.Group: transforms.preprocess([transforms.unalias_group]),
309
311
  exp.ILike: no_ilike_sql,
310
312
  exp.LogicalOr: rename_func("MAX"),