sqlglot 28.4.0__py3-none-any.whl → 28.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sqlglot/_version.py +2 -2
  2. sqlglot/dialects/bigquery.py +20 -23
  3. sqlglot/dialects/clickhouse.py +2 -0
  4. sqlglot/dialects/dialect.py +355 -18
  5. sqlglot/dialects/doris.py +38 -90
  6. sqlglot/dialects/druid.py +1 -0
  7. sqlglot/dialects/duckdb.py +1739 -163
  8. sqlglot/dialects/exasol.py +17 -1
  9. sqlglot/dialects/hive.py +27 -2
  10. sqlglot/dialects/mysql.py +103 -11
  11. sqlglot/dialects/oracle.py +38 -1
  12. sqlglot/dialects/postgres.py +142 -33
  13. sqlglot/dialects/presto.py +6 -2
  14. sqlglot/dialects/redshift.py +7 -1
  15. sqlglot/dialects/singlestore.py +13 -3
  16. sqlglot/dialects/snowflake.py +271 -21
  17. sqlglot/dialects/spark.py +25 -0
  18. sqlglot/dialects/spark2.py +4 -3
  19. sqlglot/dialects/starrocks.py +152 -17
  20. sqlglot/dialects/trino.py +1 -0
  21. sqlglot/dialects/tsql.py +5 -0
  22. sqlglot/diff.py +1 -1
  23. sqlglot/expressions.py +239 -47
  24. sqlglot/generator.py +173 -44
  25. sqlglot/optimizer/annotate_types.py +129 -60
  26. sqlglot/optimizer/merge_subqueries.py +13 -2
  27. sqlglot/optimizer/qualify_columns.py +7 -0
  28. sqlglot/optimizer/resolver.py +19 -0
  29. sqlglot/optimizer/scope.py +12 -0
  30. sqlglot/optimizer/unnest_subqueries.py +7 -0
  31. sqlglot/parser.py +251 -58
  32. sqlglot/schema.py +186 -14
  33. sqlglot/tokens.py +36 -6
  34. sqlglot/transforms.py +6 -5
  35. sqlglot/typing/__init__.py +29 -10
  36. sqlglot/typing/bigquery.py +5 -10
  37. sqlglot/typing/duckdb.py +39 -0
  38. sqlglot/typing/hive.py +50 -1
  39. sqlglot/typing/mysql.py +32 -0
  40. sqlglot/typing/presto.py +0 -1
  41. sqlglot/typing/snowflake.py +80 -17
  42. sqlglot/typing/spark.py +29 -0
  43. sqlglot/typing/spark2.py +9 -1
  44. sqlglot/typing/tsql.py +21 -0
  45. {sqlglot-28.4.0.dist-info → sqlglot-28.8.0.dist-info}/METADATA +47 -2
  46. sqlglot-28.8.0.dist-info/RECORD +95 -0
  47. {sqlglot-28.4.0.dist-info → sqlglot-28.8.0.dist-info}/WHEEL +1 -1
  48. sqlglot-28.4.0.dist-info/RECORD +0 -92
  49. {sqlglot-28.4.0.dist-info → sqlglot-28.8.0.dist-info}/licenses/LICENSE +0 -0
  50. {sqlglot-28.4.0.dist-info → sqlglot-28.8.0.dist-info}/top_level.txt +0 -0
sqlglot/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '28.4.0'
- __version_tuple__ = version_tuple = (28, 4, 0)
+ __version__ = version = '28.8.0'
+ __version_tuple__ = version_tuple = (28, 8, 0)
 
  __commit_id__ = commit_id = None
sqlglot/dialects/bigquery.py CHANGED
@@ -51,6 +51,8 @@ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar, exp.JSONExtr
 
  DQUOTES_ESCAPING_JSON_FUNCTIONS = ("JSON_QUERY", "JSON_VALUE", "JSON_QUERY_ARRAY")
 
+ MAKE_INTERVAL_KWARGS = ["year", "month", "day", "hour", "minute", "second"]
+
 
  def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
      if not expression.find_ancestor(exp.From, exp.Join):
@@ -389,7 +391,9 @@ class BigQuery(Dialect):
      EXCLUDES_PSEUDOCOLUMNS_FROM_STAR = True
      QUERY_RESULTS_ARE_STRUCTS = True
      JSON_EXTRACT_SCALAR_SCALAR_ONLY = True
+     LEAST_GREATEST_IGNORES_NULLS = False
      DEFAULT_NULL_TYPE = exp.DataType.Type.BIGINT
+     PRIORITIZE_NON_LITERAL_TYPES = True
 
      # https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#initcap
      INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v\\[\\](){}/|<>!?@"^#$&~_,.:;*%+\\-'
@@ -602,12 +606,6 @@ class BigQuery(Dialect):
      "EDIT_DISTANCE": _build_levenshtein,
      "FORMAT_DATE": _build_format_time(exp.TsOrDsToDate),
      "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
-     "GREATEST": lambda args: exp.Greatest(
-         this=seq_get(args, 0), expressions=args[1:], null_if_any_null=True
-     ),
-     "LEAST": lambda args: exp.Least(
-         this=seq_get(args, 0), expressions=args[1:], null_if_any_null=True
-     ),
      "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path(exp.JSONExtractScalar),
      "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray),
      "JSON_EXTRACT_STRING_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray),
@@ -964,7 +962,7 @@ class BigQuery(Dialect):
      def _parse_make_interval(self) -> exp.MakeInterval:
          expr = exp.MakeInterval()
 
-         for arg_key in expr.arg_types:
+         for arg_key in MAKE_INTERVAL_KWARGS:
              value = self._parse_lambda()
 
              if not value:
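
The parser now walks the explicit MAKE_INTERVAL_KWARGS list rather than exp.MakeInterval.arg_types, presumably so the positional-to-keyword mapping stays stable regardless of how the expression's arg_types are ordered or extended. A minimal sketch of what this parses (the SQL is illustrative; BigQuery passes named arguments with =>):

    import sqlglot

    # Positional and named MAKE_INTERVAL arguments map onto
    # year/month/day/hour/minute/second, in that order:
    ast = sqlglot.parse_one("SELECT MAKE_INTERVAL(1, 2, day => 3)", read="bigquery")
    print(ast.sql(dialect="bigquery"))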
@@ -1069,20 +1067,23 @@ class BigQuery(Dialect):
          )
 
      def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+         func_index = self._index + 1
          this = super()._parse_column_ops(this)
 
-         if isinstance(this, exp.Dot):
-             prefix_name = this.this.name.upper()
-             func_name = this.name.upper()
-             if prefix_name == "NET":
-                 if func_name == "HOST":
-                     this = self.expression(
-                         exp.NetHost, this=seq_get(this.expression.expressions, 0)
-                     )
-             elif prefix_name == "SAFE":
-                 if func_name == "TIMESTAMP":
-                     this = _build_timestamp(this.expression.expressions)
-                     this.set("safe", True)
+         if isinstance(this, exp.Dot) and isinstance(this.expression, exp.Func):
+             prefix = this.this.name.upper()
+
+             func: t.Optional[t.Type[exp.Func]] = None
+             if prefix == "NET":
+                 func = exp.NetFunc
+             elif prefix == "SAFE":
+                 func = exp.SafeFunc
+
+             if func:
+                 # Retreat to try and parse a known function instead of an anonymous one,
+                 # which is parsed by the base column ops parser due to anonymous_func=true
+                 self._retreat(func_index)
+                 this = func(this=self._parse_function(any_token=True))
 
          return this
 
@@ -1551,7 +1552,3 @@ class BigQuery(Dialect):
          kind = f" {kind}" if kind else ""
 
          return f"{variables}{kind}{default}"
-
-     def timestamp_sql(self, expression: exp.Timestamp) -> str:
-         prefix = "SAFE." if expression.args.get("safe") else ""
-         return self.func(f"{prefix}TIMESTAMP", expression.this, expression.args.get("zone"))
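
With SAFE.- and NET.-prefixed calls now parsed generically into exp.SafeFunc/exp.NetFunc wrappers, the bespoke timestamp_sql override above becomes dead code, hence its removal. A quick sketch of the round-trip this generalization enables (the expected output is an assumption, not taken from this diff):

    import sqlglot

    ast = sqlglot.parse_one("SELECT SAFE.TIMESTAMP('2024-01-01')", read="bigquery")
    print(ast.sql(dialect="bigquery"))  # expected to round-trip the SAFE. prefix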
sqlglot/dialects/clickhouse.py CHANGED
@@ -565,6 +565,8 @@ class ClickHouse(Dialect):
      "MEDIAN": lambda self: self._parse_quantile(),
      "COLUMNS": lambda self: self._parse_columns(),
      "TUPLE": lambda self: exp.Struct.from_arg_list(self._parse_function_args(alias=True)),
+     "AND": lambda self: exp.and_(*self._parse_function_args(alias=False)),
+     "OR": lambda self: exp.or_(*self._parse_function_args(alias=False)),
  }
 
  FUNCTION_PARSERS.pop("MATCH")
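
These parsers fold ClickHouse's variadic and()/or() functions into ordinary boolean connectives, which then transpile naturally to other dialects. An illustrative check (the output shape is assumed, not verified against this release):

    import sqlglot

    print(sqlglot.transpile("SELECT and(a, b, c) FROM t", read="clickhouse", write="duckdb")[0])
    # Something along the lines of: SELECT a AND b AND c FROM t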
sqlglot/dialects/dialect.py CHANGED
@@ -19,6 +19,7 @@ from sqlglot.helper import (
      seq_get,
      suggest_closest_match_and_fail,
      to_bool,
+     ensure_list,
  )
  from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
  from sqlglot.parser import Parser
@@ -27,6 +28,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
  from sqlglot.trie import new_trie
  from sqlglot.typing import EXPRESSION_METADATA
 
+ from importlib.metadata import entry_points
+
  DATE_ADD_OR_DIFF = t.Union[
      exp.DateAdd,
      exp.DateDiff,
@@ -66,6 +69,8 @@ UNESCAPED_SEQUENCES = {
      "\\\\": "\\",
  }
 
+ PLUGIN_GROUP_NAME = "sqlglot.dialects"
+
 
  class Dialects(str, Enum):
      """Dialects supported by SQLGLot."""
@@ -153,12 +158,54 @@ class _Dialect(type):
          if isinstance(key, Dialects):
              key = key.value
 
-         # This import will lead to a new dialect being loaded, and hence, registered.
-         # We check that the key is an actual sqlglot module to avoid blindly importing
-         # files. Custom user dialects need to be imported at the top-level package, in
-         # order for them to be registered as soon as possible.
+         # 1. Try standard sqlglot modules first
          if key in DIALECT_MODULE_NAMES:
+             module = importlib.import_module(f"sqlglot.dialects.{key}")
+             # If module was already imported, the class may not be in _classes
+             # Find and register the dialect class from the module
+             if key not in cls._classes:
+                 for attr_name in dir(module):
+                     attr = getattr(module, attr_name, None)
+                     if (
+                         isinstance(attr, type)
+                         and issubclass(attr, Dialect)
+                         and attr.__name__.lower() == key
+                     ):
+                         cls._classes[key] = attr
+                         break
+             return
+
+         # 2. Try entry points (for plugins)
+         try:
+             all_eps = entry_points()
+             # Python 3.10+ has select() method, older versions use dict-like access
+             if hasattr(all_eps, "select"):
+                 eps = all_eps.select(group=PLUGIN_GROUP_NAME, name=key)
+             else:
+                 # For older Python versions, entry_points() returns a dict-like object
+                 group_eps = all_eps.get(PLUGIN_GROUP_NAME, [])  # type: ignore
+                 eps = [ep for ep in group_eps if ep.name == key]  # type: ignore
+
+             for entry_point in eps:
+                 dialect_class = entry_point.load()
+                 # Verify it's a Dialect subclass
+                 # issubclass() returns False if not a subclass, TypeError only if not a class at all
+                 if isinstance(dialect_class, type) and issubclass(dialect_class, Dialect):
+                     # Register the dialect using the entry point name (key)
+                     # The metaclass may have registered it by class name, but we need it by entry point name
+                     if key not in cls._classes:
+                         cls._classes[key] = dialect_class
+                     return
+         except ImportError:
+             # entry_point.load() failed (bad plugin - module/class doesn't exist)
+             pass
+
+         # 3. Try direct import (for backward compatibility)
+         # This allows namespace packages or explicit imports to work
+         try:
              importlib.import_module(f"sqlglot.dialects.{key}")
+         except ImportError:
+             pass
 
      @classmethod
      def __getitem__(cls, key: str) -> t.Type[Dialect]:
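
Step 2 introduces a plugin mechanism: third-party dialects can be exposed through the "sqlglot.dialects" entry-point group (the PLUGIN_GROUP_NAME above) and resolved lazily by name. A hypothetical plugin sketch, with the package and class names invented for illustration:

    # In the plugin's pyproject.toml:
    #
    #   [project.entry-points."sqlglot.dialects"]
    #   mydb = "my_pkg.dialect:MyDB"
    #
    # In my_pkg/dialect.py:
    from sqlglot.dialects.dialect import Dialect

    class MyDB(Dialect):
        pass

    # Once installed, Dialect.get_or_raise("mydb") should find MyDB via
    # step 2 above, with no explicit import of my_pkg required.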
@@ -235,7 +282,12 @@ class _Dialect(type):
          klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
          klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
-         if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+         klass.STRINGS_SUPPORT_ESCAPED_SEQUENCES = "\\" in klass.tokenizer_class.STRING_ESCAPES
+         klass.BYTE_STRINGS_SUPPORT_ESCAPED_SEQUENCES = (
+             "\\" in klass.tokenizer_class.BYTE_STRING_ESCAPES
+         )
+
+         if klass.STRINGS_SUPPORT_ESCAPED_SEQUENCES or klass.BYTE_STRINGS_SUPPORT_ESCAPED_SEQUENCES:
              klass.UNESCAPED_SEQUENCES = {
                  **UNESCAPED_SEQUENCES,
                  **klass.UNESCAPED_SEQUENCES,
@@ -650,6 +702,9 @@ class Dialect(metaclass=_Dialect):
      ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
      """Whether ArrayAgg needs to filter NULL values."""
 
+     ARRAY_FUNCS_PROPAGATES_NULLS = False
+     """Whether Array update functions return NULL when the input array is NULL."""
+
      PROMOTE_TO_INFERRED_DATETIME_TYPE = False
      """
      This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
@@ -741,6 +796,18 @@ class Dialect(metaclass=_Dialect):
      For example, in BigQuery the default type of the NULL value is INT64.
      """
 
+     LEAST_GREATEST_IGNORES_NULLS = True
+     """
+     Whether LEAST/GREATEST functions ignore NULL values, e.g:
+     - BigQuery, Snowflake, MySQL, Presto/Trino: LEAST(1, NULL, 2) -> NULL
+     - Spark, Postgres, DuckDB, TSQL: LEAST(1, NULL, 2) -> 1
+     """
+
+     PRIORITIZE_NON_LITERAL_TYPES = False
+     """
+     Whether to prioritize non-literal types over literals during type annotation.
+     """
+
      # --- Autofilled ---
 
      tokenizer_class = Tokenizer
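
This flag captures a real engine split, which BigQuery opts out of in its own hunk above (LEAST_GREATEST_IGNORES_NULLS = False). An illustration of the divergence it has to bridge (the transpiled output is not shown here, since it depends on the generator):

    import sqlglot

    # BigQuery: LEAST(1, NULL, 2) evaluates to NULL
    # Postgres: LEAST(1, NULL, 2) evaluates to 1
    print(sqlglot.transpile("SELECT LEAST(1, NULL, 2)", read="bigquery", write="postgres")[0])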
@@ -935,7 +1002,9 @@ class Dialect(metaclass=_Dialect):
 
          result = cls.get(dialect_name.strip())
          if not result:
-             suggest_closest_match_and_fail("dialect", dialect_name, list(DIALECT_MODULE_NAMES))
+             # Include both built-in dialects and any loaded dialects for better error messages
+             all_dialects = set(DIALECT_MODULE_NAMES) | set(cls._classes.keys())
+             suggest_closest_match_and_fail("dialect", dialect_name, all_dialects)
 
          assert result is not None
          return result(**kwargs)
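
Since plugin dialects never appear in DIALECT_MODULE_NAMES, the error path now also consults the runtime registry, so typos against plugin names can get sensible suggestions too. Sketch:

    from sqlglot.dialects.dialect import Dialect

    try:
        Dialect.get_or_raise("duckdbb")  # typo
    except Exception as e:
        print(e)  # should suggest "duckdb" (and plugin names, if any are loaded)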
@@ -1282,6 +1351,138 @@ def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
      )
 
 
+ def array_append_sql(
+     name: str, swap_params: bool = False
+ ) -> t.Callable[[Generator, exp.ArrayAppend | exp.ArrayPrepend], str]:
+     """
+     Transpile ARRAY_APPEND/ARRAY_PREPEND between dialects with different NULL propagation semantics.
+
+     Some dialects (Databricks, Spark, Snowflake) return NULL when the input array is NULL.
+     Others (DuckDB, Postgres) create a new single-element array instead.
+
+     Args:
+         name: Target dialect's function name (e.g., "ARRAY_APPEND", "ARRAY_PREPEND")
+         swap_params: If True, generate (element, array) order instead of (array, element).
+             DuckDB LIST_PREPEND and Postgres ARRAY_PREPEND use (element, array).
+
+     Returns:
+         A callable that generates SQL with appropriate NULL handling for the target dialect.
+         Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
+     """
+
+     def _array_append_sql(self: Generator, expression: exp.ArrayAppend | exp.ArrayPrepend) -> str:
+         this = expression.this
+         element = expression.expression
+         args = [element, this] if swap_params else [this, element]
+         func_sql = self.func(name, *args)
+
+         source_null_propagation = bool(expression.args.get("null_propagation"))
+         target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+         # No transpilation needed when source and target have matching NULL semantics
+         if source_null_propagation == target_null_propagation:
+             return func_sql
+
+         # Source propagates NULLs, target doesn't: wrap in conditional to return NULL explicitly
+         if source_null_propagation:
+             return self.sql(
+                 exp.If(
+                     this=exp.Is(this=this, expression=exp.Null()),
+                     true=exp.Null(),
+                     false=func_sql,
+                 )
+             )
+
+         # Source doesn't propagate NULLs, target does: use COALESCE to convert NULL to empty array
+         this = exp.Coalesce(expressions=[this, exp.Array(expressions=[])])
+         args = [element, this] if swap_params else [this, element]
+         return self.func(name, *args)
+
+     return _array_append_sql
+
+
+ def array_concat_sql(
+     name: str,
+ ) -> t.Callable[[Generator, exp.ArrayConcat], str]:
+     """
+     Transpile ARRAY_CONCAT/ARRAY_CAT between dialects with different NULL propagation semantics.
+
+     Some dialects (Redshift, Snowflake, Spark) return NULL when ANY input array is NULL.
+     Others (DuckDB, PostgreSQL) skip NULL arrays and continue concatenation.
+
+     Args:
+         name: Target dialect's function name (e.g., "ARRAY_CAT", "ARRAY_CONCAT", "LIST_CONCAT")
+
+     Returns:
+         A callable that generates SQL with appropriate NULL handling for the target dialect.
+         Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
+     """
+
+     def _build_func_call(self: Generator, func_name: str, args: t.Sequence[exp.Expression]) -> str:
+         """Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting."""
+         if self.ARRAY_CONCAT_IS_VAR_LEN:
+             return self.func(func_name, *args)
+         elif len(args) == 1:
+             # Single arg gets empty array to preserve semantics
+             return self.func(func_name, args[0], exp.Array(expressions=[]))
+         else:
+             # Snowflake/PostgreSQL/Redshift require binary nesting: ARRAY_CAT(a, ARRAY_CAT(b, c))
+             # Build right-deep tree recursively to avoid creating new ArrayConcat expressions
+             result = self.func(func_name, args[-2], args[-1])
+             for arg in reversed(args[:-2]):
+                 result = f"{func_name}({self.sql(arg)}, {result})"
+             return result
+
+     def _array_concat_sql(self: Generator, expression: exp.ArrayConcat) -> str:
+         this = expression.this
+         exprs = expression.expressions
+         all_args = [this] + exprs
+
+         source_null_propagation = bool(expression.args.get("null_propagation"))
+         target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+         # Skip wrapper when source and target have matching NULL semantics,
+         # or when the first argument is an array literal (which can never be NULL),
+         # or when it's a single-argument call (empty array is added, preserving NULL semantics)
+         if (
+             source_null_propagation == target_null_propagation
+             or isinstance(this, exp.Array)
+             or len(exprs) == 0
+         ):
+             return _build_func_call(self, name, all_args)
+
+         # Case 1: Source propagates NULLs, target doesn't (Snowflake → DuckDB)
+         # Check if ANY argument is NULL and return NULL explicitly
+         if source_null_propagation:
+             # Build OR-chain: a IS NULL OR b IS NULL OR c IS NULL
+             null_checks: t.List[exp.Expression] = [
+                 exp.Is(this=arg.copy(), expression=exp.Null()) for arg in all_args
+             ]
+             combined_check: exp.Expression = reduce(
+                 lambda a, b: exp.Or(this=a, expression=b), null_checks
+             )
+
+             func_sql = _build_func_call(self, name, all_args)
+
+             return self.sql(
+                 exp.If(
+                     this=combined_check,
+                     true=exp.Null(),
+                     false=func_sql,
+                 )
+             )
+
+         # Case 2: Source doesn't propagate NULLs, target does (DuckDB → Snowflake)
+         # Wrap ALL arguments in COALESCE to convert NULL → empty array
+         wrapped_args = [
+             exp.Coalesce(expressions=[arg.copy(), exp.Array(expressions=[])]) for arg in all_args
+         ]
+
+         return _build_func_call(self, name, wrapped_args)
+
+     return _array_concat_sql
+
+
  def var_map_sql(
      self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
  ) -> str:
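
Both helpers are factories: a dialect binds its local function name (and argument order) once and gets back a generator method that patches over the NULL-semantics gap. A hypothetical wiring plus transpile check (the TRANSFORMS keys and the exact output are assumptions, not taken from this diff):

    from sqlglot import exp, transpile
    from sqlglot.dialects.dialect import array_append_sql, array_concat_sql

    # e.g. in a dialect's Generator:
    #     TRANSFORMS = {
    #         exp.ArrayAppend: array_append_sql("LIST_APPEND"),
    #         exp.ArrayPrepend: array_append_sql("LIST_PREPEND", swap_params=True),
    #         exp.ArrayConcat: array_concat_sql("LIST_CONCAT"),
    #     }

    # Snowflake propagates NULLs, DuckDB does not, so Case 1 should add a guard:
    print(transpile("SELECT ARRAY_CAT(a, b)", read="snowflake", write="duckdb")[0])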
@@ -1300,6 +1501,59 @@ def var_map_sql(
      return self.func(map_func_name, *args)
 
 
+ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:
+     """
+     Transpile MONTHS_BETWEEN to dialects that don't have native support.
+
+     Snowflake's MONTHS_BETWEEN returns whole months + fractional part where:
+     - Fractional part = (DAY(date1) - DAY(date2)) / 31
+     - Special case: If both dates are last day of month, fractional part = 0
+
+     Formula: DATEDIFF('month', date2, date1) + (DAY(date1) - DAY(date2)) / 31.0
+     """
+     date1 = expression.this
+     date2 = expression.expression
+
+     # Cast to DATE to ensure consistent behavior
+     date1_cast = exp.cast(date1, exp.DataType.Type.DATE, copy=False)
+     date2_cast = exp.cast(date2, exp.DataType.Type.DATE, copy=False)
+
+     # Whole months: DATEDIFF('month', date2, date1)
+     whole_months = exp.DateDiff(this=date1_cast, expression=date2_cast, unit=exp.var("month"))
+
+     # Day components
+     day1 = exp.Day(this=date1_cast.copy())
+     day2 = exp.Day(this=date2_cast.copy())
+
+     # Last day of month components
+     last_day_of_month1 = exp.LastDay(this=date1_cast.copy())
+     last_day_of_month2 = exp.LastDay(this=date2_cast.copy())
+
+     day_of_last_day1 = exp.Day(this=last_day_of_month1)
+     day_of_last_day2 = exp.Day(this=last_day_of_month2)
+
+     # Check if both are last day of month
+     last_day1 = exp.EQ(this=day1.copy(), expression=day_of_last_day1)
+     last_day2 = exp.EQ(this=day2.copy(), expression=day_of_last_day2)
+     both_last_day = exp.And(this=last_day1, expression=last_day2)
+
+     # Fractional part: (DAY(date1) - DAY(date2)) / 31.0
+     fractional = exp.Div(
+         this=exp.Paren(this=exp.Sub(this=day1.copy(), expression=day2.copy())),
+         expression=exp.Literal.number("31.0"),
+     )
+
+     # If both are last day of month, fractional = 0, else calculate fractional
+     fractional_with_check = exp.If(
+         this=both_last_day, true=exp.Literal.number("0"), false=fractional
+     )
+
+     # Final result: whole_months + fractional
+     result = exp.Add(this=whole_months, expression=fractional_with_check)
+
+     return self.sql(result)
+
+
  def build_formatted_time(
      exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
  ) -> t.Callable[[t.List], E]:
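
To sanity-check the formula, a worked example under the semantics stated in the docstring:

    MONTHS_BETWEEN('2019-03-01', '2019-02-15')
      whole_months = DATEDIFF('month', '2019-02-15', '2019-03-01') = 1
      fractional   = (DAY('2019-03-01') - DAY('2019-02-15')) / 31.0
                   = (1 - 15) / 31.0 ≈ -0.4516
      result       ≈ 0.5484

    MONTHS_BETWEEN('2019-03-31', '2019-02-28')
      both dates are month-ends, so the fractional part is forced to 0
      and the result is exactly the whole-month difference, 1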
@@ -1899,15 +2153,54 @@ def filter_array_using_unnest(
      return self.sql(exp.Array(expressions=[filtered]))
 
 
+ def array_compact_sql(self: Generator, expression: exp.ArrayCompact) -> str:
+     lambda_id = exp.to_identifier("_u")
+     cond = exp.Is(this=lambda_id, expression=exp.null()).not_()
+     return self.sql(
+         exp.ArrayFilter(
+             this=expression.this,
+             expression=exp.Lambda(this=cond, expressions=[lambda_id]),
+         )
+     )
+
+
  def remove_from_array_using_filter(self: Generator, expression: exp.ArrayRemove) -> str:
      lambda_id = exp.to_identifier("_u")
      cond = exp.NEQ(this=lambda_id, expression=expression.expression)
-     return self.sql(
+
+     filter_sql = self.sql(
          exp.ArrayFilter(
-             this=expression.this, expression=exp.Lambda(this=cond, expressions=[lambda_id])
+             this=expression.this,
+             expression=exp.Lambda(this=cond, expressions=[lambda_id]),
          )
      )
 
+     # Handle NULL propagation for ArrayRemove
+     source_null_propagation = bool(expression.args.get("null_propagation"))
+     target_null_propagation = self.dialect.ARRAY_FUNCS_PROPAGATES_NULLS
+
+     # Source propagates NULLs (Snowflake), target doesn't (DuckDB):
+     # When removal value is NULL, return NULL instead of applying filter
+     if source_null_propagation and not target_null_propagation:
+         removal_value = expression.expression
+
+         # Optimization: skip wrapper if removal value is a non-NULL literal
+         # (e.g., 5, 'a', TRUE) or an array literal (e.g., [1, 2])
+         if (
+             isinstance(removal_value, exp.Literal) and not isinstance(removal_value, exp.Null)
+         ) or isinstance(removal_value, exp.Array):
+             return filter_sql
+
+         return self.sql(
+             exp.If(
+                 this=exp.Is(this=removal_value, expression=exp.Null()),
+                 true=exp.Null(),
+                 false=filter_sql,
+             )
+         )
+
+     return filter_sql
+
 
  def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
      return self.func(
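
ARRAY_COMPACT lowers to a plain IS NOT NULL filter, while ARRAY_REMOVE only needs the IS NULL guard when the removal value could itself be NULL. An illustrative transpile (output shape assumed, not verified):

    import sqlglot

    # Literal removal values skip the guard entirely; a column reference should
    # trigger the IF(x IS NULL, NULL, <filter>) wrapper described above.
    print(sqlglot.transpile("SELECT ARRAY_REMOVE(arr, x)", read="snowflake", write="duckdb")[0])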
@@ -2036,17 +2329,40 @@ def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect],
 
 
  def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
-     if isinstance(expression.this, exp.Explode):
-         return self.sql(
-             exp.Join(
-                 this=exp.Unnest(
-                     expressions=[expression.this.this],
-                     alias=expression.args.get("alias"),
-                     offset=isinstance(expression.this, exp.Posexplode),
-                 ),
-                 kind="cross",
-             )
+     this = expression.this
+     alias = expression.args.get("alias")
+
+     cross_join_expr: t.Optional[exp.Expression] = None
+     if isinstance(this, exp.Posexplode) and alias:
+         # Spark's `FROM x LATERAL VIEW POSEXPLODE(y) t AS pos, col` has the following semantics:
+         # - The first column is the position and the rest (1 for array, 2 for maps) are the exploded values
+         # - The position is 0-based whereas WITH ORDINALITY is 1-based
+         # For that matter, we must (1) subtract 1 from the ORDINALITY position and (2) rearrange the columns accordingly, returning:
+         # `FROM x CROSS JOIN LATERAL (SELECT pos - 1 AS pos, col FROM UNNEST(y) WITH ORDINALITY AS t(col, pos))`
+         pos, cols = alias.columns[0], alias.columns[1:]
+
+         cols = ensure_list(cols)
+         lateral_subquery = exp.select(
+             exp.alias_(pos - 1, pos),
+             *cols,
+         ).from_(
+             exp.Unnest(
+                 expressions=[this.this],
+                 offset=True,
+                 alias=exp.TableAlias(this=alias.this, columns=[*cols, pos]),
+             ),
+         )
+
+         cross_join_expr = exp.Lateral(this=lateral_subquery.subquery())
+     elif isinstance(this, exp.Explode):
+         cross_join_expr = exp.Unnest(
+             expressions=[this.this],
+             alias=alias,
          )
+
+     if cross_join_expr:
+         return self.sql(exp.Join(this=cross_join_expr, kind="cross"))
+
      return self.lateral_sql(expression)
 
 
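The POSEXPLODE path rewrites Spark's 0-based pos column on top of 1-based WITH ORDINALITY, as the inline comment spells out. An illustrative transpile, with the expected shape taken from that comment rather than from test output:

    import sqlglot

    sql = "SELECT * FROM x LATERAL VIEW POSEXPLODE(y) t AS pos, col"
    print(sqlglot.transpile(sql, read="spark", write="presto")[0])
    # Roughly: SELECT * FROM x CROSS JOIN LATERAL (
    #   SELECT pos - 1 AS pos, col FROM UNNEST(y) WITH ORDINALITY AS t(col, pos))
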
@@ -2154,3 +2470,24 @@ def regexp_replace_global_modifier(expression: exp.RegexpReplace) -> exp.Express
          modifiers = exp.Literal.string(value + "g")
 
      return modifiers
+
+
+ def getbit_sql(self: Generator, expression: exp.Getbit) -> str:
+     """
+     Generates SQL for Getbit according to DuckDB and Postgres, transpiling it if either:
+
+     1. The zero index corresponds to the least-significant bit
+     2. The input type is an integer value
+     """
+     value = expression.this
+     position = expression.expression
+
+     if not expression.args.get("zero_is_msb") and expression.is_type(
+         *exp.DataType.SIGNED_INTEGER_TYPES, *exp.DataType.UNSIGNED_INTEGER_TYPES
+     ):
+         # Use bitwise operations: (value >> position) & 1
+         shifted = exp.BitwiseRightShift(this=value, expression=position)
+         masked = exp.BitwiseAnd(this=shifted, expression=exp.Literal.number(1))
+         return self.sql(masked)
+
+     return self.func("GET_BIT", value, position)
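
A worked example of the bitwise fallback: with a 0-based least-significant-bit index, GETBIT(11, 2) becomes (11 >> 2) & 1. Since 11 is 1011 in binary, bit 2 is 0, and the rewrite agrees:

    value, position = 11, 2   # 11 == 0b1011
    assert (value >> position) & 1 == 0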