sqlglot 28.4.1__py3-none-any.whl → 28.8.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +20 -23
- sqlglot/dialects/clickhouse.py +2 -0
- sqlglot/dialects/dialect.py +355 -18
- sqlglot/dialects/doris.py +38 -90
- sqlglot/dialects/druid.py +1 -0
- sqlglot/dialects/duckdb.py +1739 -163
- sqlglot/dialects/exasol.py +17 -1
- sqlglot/dialects/hive.py +27 -2
- sqlglot/dialects/mysql.py +103 -11
- sqlglot/dialects/oracle.py +38 -1
- sqlglot/dialects/postgres.py +142 -33
- sqlglot/dialects/presto.py +6 -2
- sqlglot/dialects/redshift.py +7 -1
- sqlglot/dialects/singlestore.py +13 -3
- sqlglot/dialects/snowflake.py +271 -21
- sqlglot/dialects/spark.py +25 -0
- sqlglot/dialects/spark2.py +4 -3
- sqlglot/dialects/starrocks.py +152 -17
- sqlglot/dialects/trino.py +1 -0
- sqlglot/dialects/tsql.py +5 -0
- sqlglot/diff.py +1 -1
- sqlglot/expressions.py +239 -47
- sqlglot/generator.py +173 -44
- sqlglot/optimizer/annotate_types.py +129 -60
- sqlglot/optimizer/merge_subqueries.py +13 -2
- sqlglot/optimizer/qualify_columns.py +7 -0
- sqlglot/optimizer/resolver.py +19 -0
- sqlglot/optimizer/scope.py +12 -0
- sqlglot/optimizer/unnest_subqueries.py +7 -0
- sqlglot/parser.py +251 -58
- sqlglot/schema.py +186 -14
- sqlglot/tokens.py +36 -6
- sqlglot/transforms.py +6 -5
- sqlglot/typing/__init__.py +29 -10
- sqlglot/typing/bigquery.py +5 -10
- sqlglot/typing/duckdb.py +39 -0
- sqlglot/typing/hive.py +50 -1
- sqlglot/typing/mysql.py +32 -0
- sqlglot/typing/presto.py +0 -1
- sqlglot/typing/snowflake.py +80 -17
- sqlglot/typing/spark.py +29 -0
- sqlglot/typing/spark2.py +9 -1
- sqlglot/typing/tsql.py +21 -0
- {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/METADATA +47 -2
- sqlglot-28.8.0.dist-info/RECORD +95 -0
- {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/WHEEL +1 -1
- sqlglot-28.4.1.dist-info/RECORD +0 -92
- {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-28.4.1.dist-info → sqlglot-28.8.0.dist-info}/top_level.txt +0 -0
sqlglot/_version.py
CHANGED

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '28.4.1'
-__version_tuple__ = version_tuple = (28, 4, 1)
+__version__ = version = '28.8.0'
+__version_tuple__ = version_tuple = (28, 8, 0)
 
 __commit_id__ = commit_id = None
sqlglot/dialects/bigquery.py
CHANGED

@@ -51,6 +51,8 @@ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar, exp.JSONExtractArray]
 
 DQUOTES_ESCAPING_JSON_FUNCTIONS = ("JSON_QUERY", "JSON_VALUE", "JSON_QUERY_ARRAY")
 
+MAKE_INTERVAL_KWARGS = ["year", "month", "day", "hour", "minute", "second"]
+
 
 def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
     if not expression.find_ancestor(exp.From, exp.Join):
@@ -389,7 +391,9 @@ class BigQuery(Dialect):
     EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = True
     QUERY_RESULTS_ARE_STRUCTS = True
     JSON_EXTRACT_SCALAR_SCALAR_ONLY = True
+    LEAST_GREATEST_IGNORES_NULLS = False
     DEFAULT_NULL_TYPE = exp.DataType.Type.BIGINT
+    PRIORITIZE_NON_LITERAL_TYPES = True
 
     # https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#initcap
     INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v\\[\\](){}/|<>!?@"^#$&~_,.:;*%+\\-'
@@ -602,12 +606,6 @@ class BigQuery(Dialect):
             "EDIT_DISTANCE": _build_levenshtein,
             "FORMAT_DATE": _build_format_time(exp.TsOrDsToDate),
             "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
-            "GREATEST": lambda args: exp.Greatest(
-                this=seq_get(args, 0), expressions=args[1:], null_if_any_null=True
-            ),
-            "LEAST": lambda args: exp.Least(
-                this=seq_get(args, 0), expressions=args[1:], null_if_any_null=True
-            ),
             "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path(exp.JSONExtractScalar),
             "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray),
             "JSON_EXTRACT_STRING_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray),
@@ -964,7 +962,7 @@ class BigQuery(Dialect):
         def _parse_make_interval(self) -> exp.MakeInterval:
             expr = exp.MakeInterval()
 
-            for arg_key in …
+            for arg_key in MAKE_INTERVAL_KWARGS:
                 value = self._parse_lambda()
 
                 if not value:
@@ -1069,20 +1067,23 @@ class BigQuery(Dialect):
             )
 
         def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+            func_index = self._index + 1
             this = super()._parse_column_ops(this)
 
-            if isinstance(this, exp.Dot):
-                …
+            if isinstance(this, exp.Dot) and isinstance(this.expression, exp.Func):
+                prefix = this.this.name.upper()
+
+                func: t.Optional[t.Type[exp.Func]] = None
+                if prefix == "NET":
+                    func = exp.NetFunc
+                elif prefix == "SAFE":
+                    func = exp.SafeFunc
+
+                if func:
+                    # Retreat to try and parse a known function instead of an anonymous one,
+                    # which is parsed by the base column ops parser due to anonymous_func=true
+                    self._retreat(func_index)
+                    this = func(this=self._parse_function(any_token=True))
 
             return this
 
@@ -1551,7 +1552,3 @@ class BigQuery(Dialect):
             kind = f" {kind}" if kind else ""
 
             return f"{variables}{kind}{default}"
-
-        def timestamp_sql(self, expression: exp.Timestamp) -> str:
-            prefix = "SAFE." if expression.args.get("safe") else ""
-            return self.func(f"{prefix}TIMESTAMP", expression.this, expression.args.get("zone"))
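Note: the `_parse_column_ops` change above means `SAFE.`- and `NET.`-prefixed calls now parse into dedicated `exp.SafeFunc`/`exp.NetFunc` wrappers around known functions rather than anonymous `Dot` chains, which likely makes the removed ad-hoc `timestamp_sql` override redundant. A minimal sketch of the expected behavior; the assertion shape is an assumption, not taken from the diff:

    import sqlglot
    from sqlglot import exp

    # Per the hunk above, a SAFE.-prefixed call should be wrapped in exp.SafeFunc
    # instead of being parsed as an anonymous Dot chain.
    ast = sqlglot.parse_one("SELECT SAFE.SUBSTR(name, 1, 2) FROM t", read="bigquery")
    assert ast.find(exp.SafeFunc) is not None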
sqlglot/dialects/clickhouse.py
CHANGED

@@ -565,6 +565,8 @@ class ClickHouse(Dialect):
             "MEDIAN": lambda self: self._parse_quantile(),
             "COLUMNS": lambda self: self._parse_columns(),
             "TUPLE": lambda self: exp.Struct.from_arg_list(self._parse_function_args(alias=True)),
+            "AND": lambda self: exp.and_(*self._parse_function_args(alias=False)),
+            "OR": lambda self: exp.or_(*self._parse_function_args(alias=False)),
         }
 
         FUNCTION_PARSERS.pop("MATCH")
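Note: ClickHouse exposes variadic `and()`/`or()` functions; the new parsers lower them to ordinary conjunctions via `exp.and_`/`exp.or_`, so they should transpile cleanly to other dialects. A rough sketch (expected output assumed, not verified against this exact release):

    import sqlglot

    # ClickHouse's variadic and() should become a chained AND elsewhere.
    print(sqlglot.transpile("SELECT and(a, b, c)", read="clickhouse", write="postgres")[0])
    # Expected: SELECT a AND b AND c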
sqlglot/dialects/dialect.py
CHANGED

@@ -19,6 +19,7 @@ from sqlglot.helper import (
     seq_get,
     suggest_closest_match_and_fail,
     to_bool,
+    ensure_list,
 )
 from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
 from sqlglot.parser import Parser
@@ -27,6 +28,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie
 from sqlglot.typing import EXPRESSION_METADATA
 
+from importlib.metadata import entry_points
+
 DATE_ADD_OR_DIFF = t.Union[
     exp.DateAdd,
     exp.DateDiff,
@@ -66,6 +69,8 @@ UNESCAPED_SEQUENCES = {
     "\\\\": "\\",
 }
 
+PLUGIN_GROUP_NAME = "sqlglot.dialects"
+
 
 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""
@@ -153,12 +158,54 @@ class _Dialect(type):
         if isinstance(key, Dialects):
             key = key.value
 
-        #
-        # We check that the key is an actual sqlglot module to avoid blindly importing
-        # files. Custom user dialects need to be imported at the top-level package, in
-        # order for them to be registered as soon as possible.
+        # 1. Try standard sqlglot modules first
         if key in DIALECT_MODULE_NAMES:
+            module = importlib.import_module(f"sqlglot.dialects.{key}")
+            # If module was already imported, the class may not be in _classes
+            # Find and register the dialect class from the module
+            if key not in cls._classes:
+                for attr_name in dir(module):
+                    attr = getattr(module, attr_name, None)
+                    if (
+                        isinstance(attr, type)
+                        and issubclass(attr, Dialect)
+                        and attr.__name__.lower() == key
+                    ):
+                        cls._classes[key] = attr
+                        break
+            return
+
+        # 2. Try entry points (for plugins)
+        try:
+            all_eps = entry_points()
+            # Python 3.10+ has select() method, older versions use dict-like access
+            if hasattr(all_eps, "select"):
+                eps = all_eps.select(group=PLUGIN_GROUP_NAME, name=key)
+            else:
+                # For older Python versions, entry_points() returns a dict-like object
+                group_eps = all_eps.get(PLUGIN_GROUP_NAME, [])  # type: ignore
+                eps = [ep for ep in group_eps if ep.name == key]  # type: ignore
+
+            for entry_point in eps:
+                dialect_class = entry_point.load()
+                # Verify it's a Dialect subclass
+                # issubclass() returns False if not a subclass, TypeError only if not a class at all
+                if isinstance(dialect_class, type) and issubclass(dialect_class, Dialect):
+                    # Register the dialect using the entry point name (key)
+                    # The metaclass may have registered it by class name, but we need it by entry point name
+                    if key not in cls._classes:
+                        cls._classes[key] = dialect_class
+                    return
+        except ImportError:
+            # entry_point.load() failed (bad plugin - module/class doesn't exist)
+            pass
+
+        # 3. Try direct import (for backward compatibility)
+        # This allows namespace packages or explicit imports to work
+        try:
             importlib.import_module(f"sqlglot.dialects.{key}")
+        except ImportError:
+            pass
 
     @classmethod
     def __getitem__(cls, key: str) -> t.Type[Dialect]:
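Note: together with `PLUGIN_GROUP_NAME = "sqlglot.dialects"`, this loader makes third-party dialects resolvable lazily by name through Python entry points. A minimal sketch of such a plugin, assuming a hypothetical package `myglot`:

    # myglot/dialect.py (hypothetical plugin package)
    from sqlglot.dialects.postgres import Postgres

    class MyDialect(Postgres):
        """Any Dialect subclass can be exposed as a plugin."""

    # The plugin's pyproject.toml would declare the entry point (TOML, shown here
    # as a comment):
    #
    #   [project.entry-points."sqlglot.dialects"]
    #   mydialect = "myglot.dialect:MyDialect"
    #
    # After installation, the loader above resolves the name on demand:
    #   import sqlglot
    #   sqlglot.parse_one("SELECT 1", read="mydialect")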
@@ -235,7 +282,12 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
-        …
+        klass.STRINGS_SUPPORT_ESCAPED_SEQUENCES = "\\" in klass.tokenizer_class.STRING_ESCAPES
+        klass.BYTE_STRINGS_SUPPORT_ESCAPED_SEQUENCES = (
+            "\\" in klass.tokenizer_class.BYTE_STRING_ESCAPES
+        )
+
+        if klass.STRINGS_SUPPORT_ESCAPED_SEQUENCES or klass.BYTE_STRINGS_SUPPORT_ESCAPED_SEQUENCES:
             klass.UNESCAPED_SEQUENCES = {
                 **UNESCAPED_SEQUENCES,
                 **klass.UNESCAPED_SEQUENCES,
@@ -650,6 +702,9 @@ class Dialect(metaclass=_Dialect):
     ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
     """Whether ArrayAgg needs to filter NULL values."""
 
+    ARRAY_FUNCS_PROPAGATES_NULLS = False
+    """Whether Array update functions return NULL when the input array is NULL."""
+
     PROMOTE_TO_INFERRED_DATETIME_TYPE = False
     """
     This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
@@ -741,6 +796,18 @@
     For example, in BigQuery the default type of the NULL value is INT64.
     """
 
+    LEAST_GREATEST_IGNORES_NULLS = True
+    """
+    Whether LEAST/GREATEST functions ignore NULL values, e.g:
+    - BigQuery, Snowflake, MySQL, Presto/Trino: LEAST(1, NULL, 2) -> NULL
+    - Spark, Postgres, DuckDB, TSQL: LEAST(1, NULL, 2) -> 1
+    """
+
+    PRIORITIZE_NON_LITERAL_TYPES = False
+    """
+    Whether to prioritize non-literal types over literals during type annotation.
+    """
+
     # --- Autofilled ---
 
     tokenizer_class = Tokenizer
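Note: the two new flags are plain class-level dialect settings. A quick sketch of how they surface, based on the values set in this release (BigQuery overrides the LEAST/GREATEST flag in its hunk above; DuckDB keeps the default):

    from sqlglot import Dialect

    assert Dialect.get_or_raise("bigquery").LEAST_GREATEST_IGNORES_NULLS is False
    assert Dialect.get_or_raise("duckdb").LEAST_GREATEST_IGNORES_NULLS is True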
@@ -935,7 +1002,9 @@
 
         result = cls.get(dialect_name.strip())
         if not result:
-            …
+            # Include both built-in dialects and any loaded dialects for better error messages
+            all_dialects = set(DIALECT_MODULE_NAMES) | set(cls._classes.keys())
+            suggest_closest_match_and_fail("dialect", dialect_name, all_dialects)
 
         assert result is not None
         return result(**kwargs)
@@ -1282,6 +1351,138 @@ def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
     )
 
 
+def array_append_sql(
+    name: str, swap_params: bool = False
+) -> t.Callable[[Generator, exp.ArrayAppend | exp.ArrayPrepend], str]:
+    """
+    Transpile ARRAY_APPEND/ARRAY_PREPEND between dialects with different NULL propagation semantics.
+
+    Some dialects (Databricks, Spark, Snowflake) return NULL when the input array is NULL.
+    Others (DuckDB, Postgres) create a new single-element array instead.
+
+    Args:
+        name: Target dialect's function name (e.g., "ARRAY_APPEND", "ARRAY_PREPEND")
+        swap_params: If True, generate (element, array) order instead of (array, element).
+            DuckDB LIST_PREPEND and Postgres ARRAY_PREPEND use (element, array).
+
+    Returns:
+        A callable that generates SQL with appropriate NULL handling for the target dialect.
+        Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
+    """
+
+    def _array_append_sql(self: Generator, expression: exp.ArrayAppend | exp.ArrayPrepend) -> str:
+        this = expression.this
+        element = expression.expression
+        args = [element, this] if swap_params else [this, element]
+        func_sql = self.func(name, *args)
+
+        source_null_propagation = bool(expression.args.get("null_propagation"))
+        target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+        # No transpilation needed when source and target have matching NULL semantics
+        if source_null_propagation == target_null_propagation:
+            return func_sql
+
+        # Source propagates NULLs, target doesn't: wrap in conditional to return NULL explicitly
+        if source_null_propagation:
+            return self.sql(
+                exp.If(
+                    this=exp.Is(this=this, expression=exp.Null()),
+                    true=exp.Null(),
+                    false=func_sql,
+                )
+            )
+
+        # Source doesn't propagate NULLs, target does: use COALESCE to convert NULL to empty array
+        this = exp.Coalesce(expressions=[this, exp.Array(expressions=[])])
+        args = [element, this] if swap_params else [this, element]
+        return self.func(name, *args)
+
+    return _array_append_sql
+
+
+def array_concat_sql(
+    name: str,
+) -> t.Callable[[Generator, exp.ArrayConcat], str]:
+    """
+    Transpile ARRAY_CONCAT/ARRAY_CAT between dialects with different NULL propagation semantics.
+
+    Some dialects (Redshift, Snowflake, Spark) return NULL when ANY input array is NULL.
+    Others (DuckDB, PostgreSQL) skip NULL arrays and continue concatenation.
+
+    Args:
+        name: Target dialect's function name (e.g., "ARRAY_CAT", "ARRAY_CONCAT", "LIST_CONCAT")
+
+    Returns:
+        A callable that generates SQL with appropriate NULL handling for the target dialect.
+        Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
+    """
+
+    def _build_func_call(self: Generator, func_name: str, args: t.Sequence[exp.Expression]) -> str:
+        """Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting."""
+        if self.ARRAY_CONCAT_IS_VAR_LEN:
+            return self.func(func_name, *args)
+        elif len(args) == 1:
+            # Single arg gets empty array to preserve semantics
+            return self.func(func_name, args[0], exp.Array(expressions=[]))
+        else:
+            # Snowflake/PostgreSQL/Redshift require binary nesting: ARRAY_CAT(a, ARRAY_CAT(b, c))
+            # Build right-deep tree recursively to avoid creating new ArrayConcat expressions
+            result = self.func(func_name, args[-2], args[-1])
+            for arg in reversed(args[:-2]):
+                result = f"{func_name}({self.sql(arg)}, {result})"
+            return result
+
+    def _array_concat_sql(self: Generator, expression: exp.ArrayConcat) -> str:
+        this = expression.this
+        exprs = expression.expressions
+        all_args = [this] + exprs
+
+        source_null_propagation = bool(expression.args.get("null_propagation"))
+        target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+        # Skip wrapper when source and target have matching NULL semantics,
+        # or when the first argument is an array literal (which can never be NULL),
+        # or when it's a single-argument call (empty array is added, preserving NULL semantics)
+        if (
+            source_null_propagation == target_null_propagation
+            or isinstance(this, exp.Array)
+            or len(exprs) == 0
+        ):
+            return _build_func_call(self, name, all_args)
+
+        # Case 1: Source propagates NULLs, target doesn't (Snowflake → DuckDB)
+        # Check if ANY argument is NULL and return NULL explicitly
+        if source_null_propagation:
+            # Build OR-chain: a IS NULL OR b IS NULL OR c IS NULL
+            null_checks: t.List[exp.Expression] = [
+                exp.Is(this=arg.copy(), expression=exp.Null()) for arg in all_args
+            ]
+            combined_check: exp.Expression = reduce(
+                lambda a, b: exp.Or(this=a, expression=b), null_checks
+            )
+
+            func_sql = _build_func_call(self, name, all_args)
+
+            return self.sql(
+                exp.If(
+                    this=combined_check,
+                    true=exp.Null(),
+                    false=func_sql,
+                )
+            )
+
+        # Case 2: Source doesn't propagate NULLs, target does (DuckDB → Snowflake)
+        # Wrap ALL arguments in COALESCE to convert NULL → empty array
+        wrapped_args = [
+            exp.Coalesce(expressions=[arg.copy(), exp.Array(expressions=[])]) for arg in all_args
+        ]
+
+        return _build_func_call(self, name, wrapped_args)
+
+    return _array_concat_sql
+
+
 def var_map_sql(
     self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 ) -> str:
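Note: in practice these factories let a NULL-propagating source and a non-propagating target round-trip correctly. A sketch, assuming the Snowflake and DuckDB dialects are wired to these helpers in this release:

    import sqlglot

    # Snowflake's ARRAY_APPEND returns NULL for a NULL input array; DuckDB's
    # does not, so the transpiled SQL should gain an explicit guard, roughly:
    #   CASE WHEN arr IS NULL THEN NULL ELSE ARRAY_APPEND(arr, 1) END
    print(sqlglot.transpile("SELECT ARRAY_APPEND(arr, 1)", read="snowflake", write="duckdb")[0])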
@@ -1300,6 +1501,59 @@ def var_map_sql(
     return self.func(map_func_name, *args)
 
 
+def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:
+    """
+    Transpile MONTHS_BETWEEN to dialects that don't have native support.
+
+    Snowflake's MONTHS_BETWEEN returns whole months + fractional part where:
+    - Fractional part = (DAY(date1) - DAY(date2)) / 31
+    - Special case: If both dates are last day of month, fractional part = 0
+
+    Formula: DATEDIFF('month', date2, date1) + (DAY(date1) - DAY(date2)) / 31.0
+    """
+    date1 = expression.this
+    date2 = expression.expression
+
+    # Cast to DATE to ensure consistent behavior
+    date1_cast = exp.cast(date1, exp.DataType.Type.DATE, copy=False)
+    date2_cast = exp.cast(date2, exp.DataType.Type.DATE, copy=False)
+
+    # Whole months: DATEDIFF('month', date2, date1)
+    whole_months = exp.DateDiff(this=date1_cast, expression=date2_cast, unit=exp.var("month"))
+
+    # Day components
+    day1 = exp.Day(this=date1_cast.copy())
+    day2 = exp.Day(this=date2_cast.copy())
+
+    # Last day of month components
+    last_day_of_month1 = exp.LastDay(this=date1_cast.copy())
+    last_day_of_month2 = exp.LastDay(this=date2_cast.copy())
+
+    day_of_last_day1 = exp.Day(this=last_day_of_month1)
+    day_of_last_day2 = exp.Day(this=last_day_of_month2)
+
+    # Check if both are last day of month
+    last_day1 = exp.EQ(this=day1.copy(), expression=day_of_last_day1)
+    last_day2 = exp.EQ(this=day2.copy(), expression=day_of_last_day2)
+    both_last_day = exp.And(this=last_day1, expression=last_day2)
+
+    # Fractional part: (DAY(date1) - DAY(date2)) / 31.0
+    fractional = exp.Div(
+        this=exp.Paren(this=exp.Sub(this=day1.copy(), expression=day2.copy())),
+        expression=exp.Literal.number("31.0"),
+    )
+
+    # If both are last day of month, fractional = 0, else calculate fractional
+    fractional_with_check = exp.If(
+        this=both_last_day, true=exp.Literal.number("0"), false=fractional
+    )
+
+    # Final result: whole_months + fractional
+    result = exp.Add(this=whole_months, expression=fractional_with_check)
+
+    return self.sql(result)
+
+
 def build_formatted_time(
     exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 ) -> t.Callable[[t.List], E]:
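Note: a worked instance of the documented formula (plain arithmetic, no sqlglot required):

    # MONTHS_BETWEEN('2024-03-15', '2024-01-10'):
    whole_months = 2               # DATEDIFF('month', '2024-01-10', '2024-03-15')
    fractional = (15 - 10) / 31.0  # (DAY(date1) - DAY(date2)) / 31.0
    assert round(whole_months + fractional, 6) == 2.16129

    # Last-day special case: MONTHS_BETWEEN('2024-02-29', '2024-01-31') is exactly 1,
    # since both dates are month-ends and the fractional part is forced to 0.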
@@ -1899,15 +2153,54 @@ def filter_array_using_unnest(
     return self.sql(exp.Array(expressions=[filtered]))
 
 
+def array_compact_sql(self: Generator, expression: exp.ArrayCompact) -> str:
+    lambda_id = exp.to_identifier("_u")
+    cond = exp.Is(this=lambda_id, expression=exp.null()).not_()
+    return self.sql(
+        exp.ArrayFilter(
+            this=expression.this,
+            expression=exp.Lambda(this=cond, expressions=[lambda_id]),
+        )
+    )
+
+
 def remove_from_array_using_filter(self: Generator, expression: exp.ArrayRemove) -> str:
     lambda_id = exp.to_identifier("_u")
     cond = exp.NEQ(this=lambda_id, expression=expression.expression)
-    …
+
+    filter_sql = self.sql(
         exp.ArrayFilter(
-            this=expression.this,…
+            this=expression.this,
+            expression=exp.Lambda(this=cond, expressions=[lambda_id]),
         )
     )
 
+    # Handle NULL propagation for ArrayRemove
+    source_null_propagation = bool(expression.args.get("null_propagation"))
+    target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+    # Source propagates NULLs (Snowflake), target doesn't (DuckDB):
+    # When removal value is NULL, return NULL instead of applying filter
+    if source_null_propagation and not target_null_propagation:
+        removal_value = expression.expression
+
+        # Optimization: skip wrapper if removal value is a non-NULL literal
+        # (e.g., 5, 'a', TRUE) or an array literal (e.g., [1, 2])
+        if (
+            isinstance(removal_value, exp.Literal) and not isinstance(removal_value, exp.Null)
+        ) or isinstance(removal_value, exp.Array):
+            return filter_sql
+
+        return self.sql(
+            exp.If(
+                this=exp.Is(this=removal_value, expression=exp.Null()),
+                true=exp.Null(),
+                false=filter_sql,
+            )
+        )
+
+    return filter_sql
 
 def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
     return self.func(
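Note: `array_compact_sql` rewrites ARRAY_COMPACT as a NOT-NULL lambda filter for dialects without a native equivalent. A sketch of the expected output shape, assuming the Snowflake reader and DuckDB writer use this helper in this release:

    import sqlglot

    # Roughly: LIST_FILTER(x, _u -> NOT _u IS NULL) on the DuckDB side
    print(sqlglot.transpile("SELECT ARRAY_COMPACT(x)", read="snowflake", write="duckdb")[0])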
@@ -2036,17 +2329,40 @@ def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
 
 
 def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
-    …
+    this = expression.this
+    alias = expression.args.get("alias")
+
+    cross_join_expr: t.Optional[exp.Expression] = None
+    if isinstance(this, exp.Posexplode) and alias:
+        # Spark's `FROM x LATERAL VIEW POSEXPLODE(y) t AS pos, col` has the following semantics:
+        # - The first column is the position and the rest (1 for array, 2 for maps) are the exploded values
+        # - The position is 0-based whereas WITH ORDINALITY is 1-based
+        # For that matter, we must (1) subtract 1 from the ORDINALITY position and (2) rearrange the columns accordingly, returning:
+        # `FROM x CROSS JOIN LATERAL (SELECT pos - 1 AS pos, col FROM UNNEST(y) WITH ORDINALITY AS t(col, pos))`
+        pos, cols = alias.columns[0], alias.columns[1:]
+
+        cols = ensure_list(cols)
+        lateral_subquery = exp.select(
+            exp.alias_(pos - 1, pos),
+            *cols,
+        ).from_(
+            exp.Unnest(
+                expressions=[this.this],
+                offset=True,
+                alias=exp.TableAlias(this=alias.this, columns=[*cols, pos]),
+            ),
+        )
+
+        cross_join_expr = exp.Lateral(this=lateral_subquery.subquery())
+    elif isinstance(this, exp.Explode):
+        cross_join_expr = exp.Unnest(
+            expressions=[this.this],
+            alias=alias,
         )
+
+    if cross_join_expr:
+        return self.sql(exp.Join(this=cross_join_expr, kind="cross"))
+
     return self.lateral_sql(expression)
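Note: per the inline comment above, Spark's 0-based POSEXPLODE is rewritten into a LATERAL subquery over `UNNEST ... WITH ORDINALITY`, shifting the 1-based ordinality down by one and reordering the alias columns. A sketch, assuming the target dialect routes `exp.Lateral` through this helper (Postgres is shown as an assumption):

    import sqlglot

    sql = "SELECT * FROM x LATERAL VIEW POSEXPLODE(y) t AS pos, col"
    print(sqlglot.transpile(sql, read="spark", write="postgres")[0])
    # Expected shape: ... CROSS JOIN LATERAL (SELECT pos - 1 AS pos, col
    #   FROM UNNEST(y) WITH ORDINALITY AS t(col, pos)) ...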
@@ -2154,3 +2470,24 @@ def regexp_replace_global_modifier(expression: exp.RegexpReplace) -> exp.Expression:
         modifiers = exp.Literal.string(value + "g")
 
     return modifiers
+
+
+def getbit_sql(self: Generator, expression: exp.Getbit) -> str:
+    """
+    Generates SQL for Getbit according to DuckDB and Postgres, transpiling it if either:
+
+    1. The zero index corresponds to the least-significant bit
+    2. The input type is an integer value
+    """
+    value = expression.this
+    position = expression.expression
+
+    if not expression.args.get("zero_is_msb") and expression.is_type(
+        *exp.DataType.SIGNED_INTEGER_TYPES, *exp.DataType.UNSIGNED_INTEGER_TYPES
+    ):
+        # Use bitwise operations: (value >> position) & 1
+        shifted = exp.BitwiseRightShift(this=value, expression=position)
+        masked = exp.BitwiseAnd(this=shifted, expression=exp.Literal.number(1))
+        return self.sql(masked)
+
+    return self.func("GET_BIT", value, position)