snowpark-connect 0.29.0__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff covers the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

This version of snowpark-connect has been flagged as potentially problematic.
Files changed (41)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +25 -3
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +253 -59
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  12. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  13. snowflake/snowpark_connect/relation/io_utils.py +61 -4
  14. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  15. snowflake/snowpark_connect/relation/map_join.py +8 -0
  16. snowflake/snowpark_connect/relation/map_row_ops.py +129 -17
  17. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  18. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  19. snowflake/snowpark_connect/relation/map_stats.py +21 -6
  20. snowflake/snowpark_connect/relation/read/map_read.py +9 -0
  21. snowflake/snowpark_connect/relation/read/map_read_csv.py +17 -6
  22. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -2
  23. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  24. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  25. snowflake/snowpark_connect/relation/utils.py +19 -2
  26. snowflake/snowpark_connect/relation/write/map_write.py +44 -29
  27. snowflake/snowpark_connect/server.py +11 -3
  28. snowflake/snowpark_connect/type_mapping.py +75 -3
  29. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  30. snowflake/snowpark_connect/utils/telemetry.py +105 -23
  31. snowflake/snowpark_connect/version.py +1 -1
  32. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/METADATA +1 -1
  33. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/RECORD +41 -37
  34. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-connect +0 -0
  35. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-session +0 -0
  36. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-submit +0 -0
  37. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/WHEEL +0 -0
  38. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE-binary +0 -0
  39. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/NOTICE-binary +0 -0
  41. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@ from pyspark.sql.connect import functions as pyspark_functions
 
 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_proto
 from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
+from snowflake.snowpark.types import DayTimeIntervalType, YearMonthIntervalType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.typed_column import TypedColumn
@@ -33,6 +34,24 @@ from .typer import ExpressionTyper
 
 DECIMAL_RE = re.compile(r"decimal\((\d+), *(\d+)\)")
 
+_INTERVAL_YEARMONTH_PATTERN_RE = re.compile(r"interval (year|month)( to (year|month))?")
+_INTERVAL_DAYTIME_PATTERN_RE = re.compile(
+    r"interval (day|hour|minute|second)( to (day|hour|minute|second))?"
+)
+
+# Interval field mappings using proper constants
+_YEAR_MONTH_FIELD_MAP = {
+    "year": YearMonthIntervalType.YEAR,
+    "month": YearMonthIntervalType.MONTH,
+}
+
+_DAY_TIME_FIELD_MAP = {
+    "day": DayTimeIntervalType.DAY,
+    "hour": DayTimeIntervalType.HOUR,
+    "minute": DayTimeIntervalType.MINUTE,
+    "second": DayTimeIntervalType.SECOND,
+}
+
 _window_specs = ContextVar[dict[str, any]]("_window_specs", default={})
 
 
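A note for reviewers: the two patterns above decompose a Spark type-name string into interval start/end fields, with group(1) as the start unit and group(3) as the optional end unit. A minimal standalone sketch (the input strings are illustrative, not taken from this diff):

    import re

    _INTERVAL_DAYTIME_PATTERN_RE = re.compile(
        r"interval (day|hour|minute|second)( to (day|hour|minute|second))?"
    )

    m = _INTERVAL_DAYTIME_PATTERN_RE.match("interval day to second")
    print(m.group(1), m.group(3))  # 'day' 'second' -> start and end fields

    m = _INTERVAL_DAYTIME_PATTERN_RE.match("interval hour")
    print(m.group(1), m.group(3))  # 'hour' None -> end field defaults to the start field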
@@ -388,8 +407,94 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Expression:
             type_value = types_proto.DataType()
         elif type_name == "binary":
             type_value = bytes(type_value)
-        elif type_name.startswith("interval "):
-            type_name = "day_time_interval"  # TODO
+        elif year_month_match := _INTERVAL_YEARMONTH_PATTERN_RE.match(type_name):
+            # Extract start and end fields for year-month intervals
+            start_field_name = year_month_match.group(1)  # 'year' or 'month'
+            end_field_name = (
+                year_month_match.group(3)
+                if year_month_match.group(3)
+                else start_field_name
+            )
+
+            # Validate field names exist in mapping
+            start_field = _YEAR_MONTH_FIELD_MAP.get(start_field_name)
+            end_field = _YEAR_MONTH_FIELD_MAP.get(end_field_name)
+
+            if start_field is None:
+                raise AnalysisException(
+                    f"Invalid year-month interval start field: '{start_field_name}'. Expected 'year' or 'month'."
+                )
+            if end_field is None:
+                raise AnalysisException(
+                    f"Invalid year-month interval end field: '{end_field_name}'. Expected 'year' or 'month'."
+                )
+
+            # Validate field ordering (start_field should be <= end_field)
+            if start_field > end_field:
+                raise AnalysisException(
+                    f"Invalid year-month interval: start field '{start_field_name}' must come before or equal to end field '{end_field_name}'."
+                )
+
+            # Use extension for year-month intervals to preserve start/end field info
+            literal = expressions_proto.Expression.Literal(
+                year_month_interval=type_value
+            )
+            any_proto = Any()
+            any_proto.Pack(
+                snowflake_proto.ExpExtension(
+                    interval_literal=snowflake_proto.IntervalLiteralExpression(
+                        literal=literal,
+                        start_field=start_field,
+                        end_field=end_field,
+                    )
+                )
+            )
+            return expressions_proto.Expression(extension=any_proto)
+        elif day_time_match := _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
+            # Extract start and end fields for day-time intervals
+            start_field_name = day_time_match.group(
+                1
+            )  # 'day', 'hour', 'minute', 'second'
+            end_field_name = (
+                day_time_match.group(3)
+                if day_time_match.group(3)
+                else start_field_name
+            )
+
+            # Validate field names exist in mapping
+            start_field = _DAY_TIME_FIELD_MAP.get(start_field_name)
+            end_field = _DAY_TIME_FIELD_MAP.get(end_field_name)
+
+            if start_field is None:
+                raise AnalysisException(
+                    f"Invalid day-time interval start field: '{start_field_name}'. Expected 'day', 'hour', 'minute', or 'second'."
+                )
+            if end_field is None:
+                raise AnalysisException(
+                    f"Invalid day-time interval end field: '{end_field_name}'. Expected 'day', 'hour', 'minute', or 'second'."
+                )
+
+            # Validate field ordering (start_field should be <= end_field)
+            if start_field > end_field:
+                raise AnalysisException(
+                    f"Invalid day-time interval: start field '{start_field_name}' must come before or equal to end field '{end_field_name}'."
+                )
+
+            # Use extension for day-time intervals to preserve start/end field info
+            literal = expressions_proto.Expression.Literal(
+                day_time_interval=type_value
+            )
+            any_proto = Any()
+            any_proto.Pack(
+                snowflake_proto.ExpExtension(
+                    interval_literal=snowflake_proto.IntervalLiteralExpression(
+                        literal=literal,
+                        start_field=start_field,
+                        end_field=end_field,
+                    )
+                )
+            )
+            return expressions_proto.Expression(extension=any_proto)
         elif m := DECIMAL_RE.fullmatch(type_name):
             type_name = "decimal"
             type_value = expressions_proto.Expression.Literal.Decimal(
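A reviewer's sketch of what this enables on the client side, assuming a Spark Connect session served by snowpark-connect (the column aliases are illustrative): interval literals now carry their start/end fields through the Snowflake extension proto instead of being collapsed to a generic day-time interval.

    from pyspark.sql.functions import expr

    df = spark.range(1).select(
        expr("INTERVAL '2' YEAR").alias("ym"),                      # year-month interval literal
        expr("INTERVAL '1 02:03:04' DAY TO SECOND").alias("dts"),   # day-time interval literal
    )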
@@ -20,7 +20,7 @@ from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -50,6 +50,7 @@ from snowflake.snowpark.types import (
     ByteType,
     DataType,
     DateType,
+    DayTimeIntervalType,
     DecimalType,
     DoubleType,
     FloatType,
@@ -64,6 +65,7 @@ from snowflake.snowpark.types import (
     TimestampTimeZone,
     TimestampType,
     VariantType,
+    YearMonthIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -639,8 +641,17 @@ def map_unresolved_function(
             if isinstance(t, (IntegerType, ShortType, ByteType)):
                 result_type = DateType()
                 result_exp = snowpark_args[0] + snowpark_args[1]
+            elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                result_type = TimestampType()
+                result_exp = (
+                    snowpark_args[date_param_index]
+                    + snowpark_args[t_param_index]
+                )
             elif (
-                "INTERVAL"
+                hasattr(
+                    snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                )
+                and "INTERVAL"
                 == snowpark_typed_args[t_param_index].col._expr1.pretty_name
             ):
                 result_type = TimestampType()
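In client terms, adding a typed interval to a date column now resolves to a timestamp result type up front, rather than relying only on sniffing pretty_name on the underlying expression. A hypothetical example (df and column d are illustrative):

    from pyspark.sql.functions import col, expr

    df.select(col("d") + expr("INTERVAL '3' DAY"))    # DateType + DayTimeIntervalType -> TimestampType
    df.select(col("d") + expr("INTERVAL '2' MONTH"))  # DateType + YearMonthIntervalType -> TimestampType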
@@ -740,6 +751,12 @@ def map_unresolved_function(
                 # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                 result_type = LongType()
                 result_exp = snowpark_args[0] - snowpark_args[1]
+            case (DateType(), DayTimeIntervalType()) | (
+                DateType(),
+                YearMonthIntervalType(),
+            ):
+                result_type = TimestampType()
+                result_exp = snowpark_args[0] - snowpark_args[1]
             case (DateType(), StringType()):
                 if (
                     hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
@@ -944,6 +961,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} < {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -958,6 +979,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} <= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<=", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -976,6 +1001,10 @@ def map_unresolved_function(
             )
             result_exp = TypedColumn(left.eqNullSafe(right), lambda: [BooleanType()])
         case "==" | "=":
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "=", snowpark_typed_args, snowpark_arg_names
+            )
             spark_function_name = f"({snowpark_arg_names[0]} = {snowpark_arg_names[1]})"
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
@@ -991,6 +1020,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} > {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -1005,6 +1038,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} >= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">=", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -1135,9 +1172,7 @@ def map_unresolved_function(
             # SNOW-1955784: Support accuracy parameter
             # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-            def _pyspark_approx_percentile(
-                column: Column, percentage: float, original_type: DataType
-            ) -> Column:
+            def _pyspark_approx_percentile(column: Column, percentage: float) -> Column:
                 """
                 PySpark-compatible percentile that returns actual values from dataset.
                 - PySpark's approx_percentile returns the "smallest value in the ordered col values
@@ -1154,7 +1189,7 @@ def map_unresolved_function(
                 result = snowpark_fn.function("percentile_disc")(
                     snowpark_fn.lit(percentage)
                 ).within_group(column)
-                return snowpark_fn.cast(result, original_type)
+                return result
 
             column_type = snowpark_typed_args[0].typ
 
@@ -1165,26 +1200,18 @@ def map_unresolved_function(
                 assert array_func.function_name == "array", array_func
 
                 percentile_results = [
-                    _pyspark_approx_percentile(
-                        snowpark_args[0], unwrap_literal(arg), column_type
-                    )
+                    _pyspark_approx_percentile(snowpark_args[0], unwrap_literal(arg))
                     for arg in array_func.arguments
                 ]
 
                 result_type = ArrayType(element_type=column_type, contains_null=False)
-                result_exp = snowpark_fn.cast(
-                    snowpark_fn.array_construct(*percentile_results),
-                    result_type,
-                )
+                result_exp = snowpark_fn.array_construct(*percentile_results)
+                result_exp = _resolve_aggregate_exp(result_exp, result_type)
             else:
                 # Handle single percentile
                 percentage = unwrap_literal(exp.unresolved_function.arguments[1])
-                result_exp = TypedColumn(
-                    _pyspark_approx_percentile(
-                        snowpark_args[0], percentage, column_type
-                    ),
-                    lambda: [column_type],
-                )
+                result_exp = _pyspark_approx_percentile(snowpark_args[0], percentage)
+                result_exp = _resolve_aggregate_exp(result_exp, column_type)
         case "array":
             if len(snowpark_args) == 0:
                 result_exp = snowpark_fn.cast(
@@ -1358,27 +1385,55 @@ def map_unresolved_function(
             result_exp = snowpark_fn.cast(result_exp, array_type)
             result_exp = TypedColumn(result_exp, lambda: snowpark_typed_args[0].types)
         case "array_repeat":
+            elem, count = snowpark_args[0], snowpark_args[1]
+            elem_type = snowpark_typed_args[0].typ
+            result_type = ArrayType(elem_type)
 
-            @cached_udf(
-                input_types=[VariantType(), LongType()],
-                return_type=ArrayType(),
-            )
-            def _array_repeat(elem, n):
-                if n is None:
-                    return None
-                if n < 0:
-                    return []
-                return [elem] * n
+            fallback_to_udf = True
 
-            result_exp = snowpark_fn.cast(
-                _array_repeat(
-                    snowpark_fn.cast(snowpark_args[0], VariantType()), snowpark_args[1]
-                ),
-                ArrayType(snowpark_typed_args[0].typ),
-            )
-            result_exp = TypedColumn(
-                result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
-            )
+            if isinstance(count._expression, Literal):
+                count_value = count._expression.value
+                fallback_to_udf = False
+
+                if count_value is None:
+                    result_exp = snowpark_fn.lit(None).cast(result_type)
+                elif count_value <= 0:
+                    result_exp = snowpark_fn.array_construct().cast(result_type)
+                elif count_value <= 16:
+                    # count_value is small enough to initialize the array directly in memory
+                    elem_variant = snowpark_fn.cast(elem, VariantType())
+                    result_exp = snowpark_fn.array_construct(
+                        *([elem_variant] * count_value)
+                    ).cast(result_type)
+                else:
+                    fallback_to_udf = True
+
+            if fallback_to_udf:
+
+                @cached_udf(
+                    input_types=[VariantType(), LongType()],
+                    return_type=ArrayType(),
+                )
+                def _array_repeat(elem, n):
+                    if n is None:
+                        return None
+                    if n < 0:
+                        return []
+                    return [elem] * n
+
+                elem_variant = snowpark_fn.cast(elem, VariantType())
+
+                result_exp = (
+                    snowpark_fn.when(
+                        count.is_null(), snowpark_fn.lit(None).cast(result_type)
+                    )
+                    .when(count <= 0, snowpark_fn.array_construct().cast(result_type))
+                    .otherwise(
+                        snowpark_fn.cast(
+                            _array_repeat(elem_variant, count), result_type
+                        )
+                    )
+                )
         case "array_size":
             array_type = snowpark_typed_args[0].typ
             if not isinstance(array_type, ArrayType):
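The practical effect of the array_repeat rewrite: a literal count of at most 16 is expanded inline with ARRAY_CONSTRUCT, while a non-literal or larger count still goes through the cached UDF guarded by the when-chain. A hypothetical client-side illustration:

    from pyspark.sql.functions import array_repeat, col, lit

    df.select(array_repeat(lit("x"), 3))            # literal small count: inline ARRAY_CONSTRUCT, no UDF
    df.select(array_repeat(col("elem"), col("n")))  # non-literal count: cached-UDF fallback path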
@@ -1578,7 +1633,7 @@ def map_unresolved_function(
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "bit_get" | "getbit":
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
             col, pos = snowpark_args
             if snowflake_compat:
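The flag consulted here (and in the hash and md5 cases below) is now namespaced. Presumably it is supplied through the ordinary session configuration; a sketch, assuming the key is settable from the client:

    spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")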
@@ -1885,14 +1940,11 @@ def map_unresolved_function(
             qualifiers = snowpark_args[0].get_qualifiers()
         case "collect_list" | "array_agg":
             # TODO: SNOW-1967177 - Support structured types in array_agg
-            result_exp = snowpark_fn.cast(
-                snowpark_fn.array_agg(
-                    snowpark_typed_args[0].column(to_semi_structure=True)
-                ),
-                ArrayType(snowpark_typed_args[0].typ),
+            result_exp = snowpark_fn.array_agg(
+                snowpark_typed_args[0].column(to_semi_structure=True)
             )
-            result_exp = TypedColumn(
-                result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
+            result_exp = _resolve_aggregate_exp(
+                result_exp, ArrayType(snowpark_typed_args[0].typ)
             )
             spark_function_name = f"collect_list({snowpark_arg_names[0]})"
         case "collect_set":
@@ -2379,15 +2431,30 @@ def map_unresolved_function(
                 # If format is NULL, return NULL for all rows
                 result_exp = snowpark_fn.lit(None)
             else:
+                format_lit = snowpark_fn.lit(
+                    map_spark_timestamp_format_expression(
+                        exp.unresolved_function.arguments[1],
+                        snowpark_typed_args[0].typ,
+                    )
+                )
                 result_exp = snowpark_fn.date_format(
                     snowpark_args[0],
-                    snowpark_fn.lit(
-                        map_spark_timestamp_format_expression(
-                            exp.unresolved_function.arguments[1],
-                            snowpark_typed_args[0].typ,
-                        )
-                    ),
+                    format_lit,
                 )
+
+                if format_literal == "EEEE":
+                    # TODO: SNOW-2356874, for weekday, Snowflake only supports abbreviated name, e.g. "Fri". Patch spark "EEEE" until
+                    # snowflake supports full weekday name.
+                    result_exp = (
+                        snowpark_fn.when(result_exp == "Mon", "Monday")
+                        .when(result_exp == "Tue", "Tuesday")
+                        .when(result_exp == "Wed", "Wednesday")
+                        .when(result_exp == "Thu", "Thursday")
+                        .when(result_exp == "Fri", "Friday")
+                        .when(result_exp == "Sat", "Saturday")
+                        .when(result_exp == "Sun", "Sunday")
+                        .otherwise(result_exp)
+                    )
             result_exp = TypedColumn(result_exp, lambda: [StringType()])
         case "date_from_unix_date":
             result_exp = snowpark_fn.date_add(
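A sketch of the patched behavior, assuming a DateType column d: Snowflake only produces abbreviated weekday names, so the when-chain above upgrades them to the full names Spark returns for the "EEEE" pattern.

    from pyspark.sql.functions import col, date_format

    df.select(date_format(col("d"), "EEEE"))  # now "Monday" ... "Sunday" instead of "Mon" ... "Sun"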
@@ -2486,6 +2553,14 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "date_part" | "datepart" | "extract":
+            # Check for interval types and throw NotImplementedError
+            if isinstance(
+                snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+            ):
+                raise NotImplementedError(
+                    f"{function_name} with interval types is not supported"
+                )
+
             field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
             if field_lit is None:
@@ -3261,7 +3336,7 @@ def map_unresolved_function(
             # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
             # MapType columns as input should raise an exception as they are not hashable.
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
             # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
             # we want to let it pass through and hash MAP types.
@@ -3768,10 +3843,21 @@ def map_unresolved_function(
                 snowpark_fn.lit(is_outer),
             )
         case "input_file_name":
-            # TODO: Use METADATA$FILENAME to get an actual file name, but this requires a refactor, because this is not possible straightaway here.
-            raise SnowparkConnectNotImplementedError(
-                "input_file_name is not yet supported."
+            # Return the filename metadata column for file-based DataFrames
+            # If METADATA$FILENAME doesn't exist (e.g., for DataFrames created from local data),
+            # return empty string to match Spark's behavior
+            from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                METADATA_FILENAME_COLUMN,
             )
+
+            available_columns = column_mapping.get_snowpark_columns()
+            if METADATA_FILENAME_COLUMN in available_columns:
+                result_exp = snowpark_fn.col(METADATA_FILENAME_COLUMN)
+            else:
+                # Return empty when METADATA$FILENAME column doesn't exist, matching Spark behavior
+                result_exp = snowpark_fn.lit("").cast(StringType())
+            result_type = StringType()
+            spark_function_name = "input_file_name()"
         case "instr":
             result_exp = snowpark_fn.charindex(snowpark_args[1], snowpark_args[0])
             result_type = LongType()
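Hypothetical usage after this change (the stage path is illustrative): file-backed reads expose the source file via the metadata column, while DataFrames built from local data yield an empty string, matching Spark.

    from pyspark.sql.functions import input_file_name

    spark.read.parquet("@my_stage/data/").select(input_file_name())   # per-row source file name
    spark.createDataFrame([(1,)], ["a"]).select(input_file_name())    # empty string, as in Spark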
@@ -4731,7 +4817,7 @@ def map_unresolved_function(
             )
         case "md5":
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
 
             # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
@@ -5305,9 +5391,14 @@ def map_unresolved_function(
             result_exp = snowpark_fn.function(function_name)(
                 _check_percentile_percentage(exp.unresolved_function.arguments[1])
             ).within_group(snowpark_args[0])
-            result_exp = TypedColumn(
-                snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+            result_exp = (
+                TypedColumn(
+                    snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+                )
+                if not is_window_enabled()
+                else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
             )
+
             spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
         case "pi":
             spark_function_name = "PI()"
@@ -7526,6 +7617,12 @@ def map_unresolved_function(
             )
             result_type = DateType()
         case "try_add":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_add with interval types is not supported"
+                    )
             result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
             result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
@@ -7579,6 +7676,12 @@ def map_unresolved_function(
                 DoubleType(), cleaned, calculating_avg=True
             )
         case "try_divide":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_divide with interval types is not supported"
+                    )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
@@ -7687,6 +7790,12 @@ def map_unresolved_function(
                 f"Expected either (ArrayType, IntegralType) or (MapType, StringType), got {snowpark_typed_args[0].typ}, {snowpark_typed_args[1].typ}."
             )
         case "try_multiply":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_multiply with interval types is not supported"
+                    )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
@@ -7785,6 +7894,12 @@ def map_unresolved_function(
                 snowpark_typed_args[0].typ, snowpark_args[0]
             )
         case "try_subtract":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_subtract with interval types is not supported"
+                    )
             result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
             result_exp = _type_with_typer(result_exp)
         case "try_to_number":
@@ -9865,6 +9980,70 @@ def _get_add_sub_result_type(
     return result_type, overflow_possible
 
 
+def _get_interval_type_name(
+    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
+) -> str:
+    """Get the formatted interval type name for error messages."""
+    if isinstance(interval_type, YearMonthIntervalType):
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL YEAR"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL MONTH"
+        else:
+            return "INTERVAL YEAR TO MONTH"
+    else:  # DayTimeIntervalType
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL DAY"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL HOUR"
+        elif interval_type.start_field == 2 and interval_type.end_field == 2:
+            return "INTERVAL MINUTE"
+        elif interval_type.start_field == 3 and interval_type.end_field == 3:
+            return "INTERVAL SECOND"
+        else:
+            return "INTERVAL DAY TO SECOND"
+
+
+def _check_interval_string_comparison(
+    operator: str, snowpark_typed_args: List[TypedColumn], snowpark_arg_names: List[str]
+) -> None:
+    """Check for invalid interval-string comparisons and raise AnalysisException if found."""
+    if (
+        isinstance(
+            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+        and isinstance(snowpark_typed_args[1].typ, StringType)
+        or isinstance(snowpark_typed_args[0].typ, StringType)
+        and isinstance(
+            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+    ):
+        # Format interval type name for error message
+        interval_type = (
+            snowpark_typed_args[0].typ
+            if isinstance(
+                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+            )
+            else snowpark_typed_args[1].typ
+        )
+        interval_name = _get_interval_type_name(interval_type)
+
+        left_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[0].typ, StringType)
+            else interval_name
+        )
+        right_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[1].typ, StringType)
+            else interval_name
+        )
+
+        raise AnalysisException(
+            f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} {operator} {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{left_type}" and "{right_type}").;'
+        )
+
+
 def _get_spark_function_name(
     col1: TypedColumn,
     col2: TypedColumn,
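With this helper, the comparison operators wired up earlier reject interval-versus-string operands with a Spark-style error; an illustrative trigger (df is hypothetical):

    from pyspark.sql.functions import expr, lit

    # Raises AnalysisException: [DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] ...
    # incompatible types ("INTERVAL DAY" and "STRING")
    df.filter(expr("INTERVAL '1' DAY") > lit("1 day"))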
@@ -9906,6 +10085,21 @@ def _get_spark_function_name(
                 return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
             else:
                 return f"{operation_func}(cast({date_param_name1} as date), cast({snowpark_arg_names[1]} as double))"
+        case (DateType(), DayTimeIntervalType()) | (
+            DateType(),
+            YearMonthIntervalType(),
+        ):
+            date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
+            return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
+        case (DayTimeIntervalType(), DateType()) | (
+            YearMonthIntervalType(),
+            DateType(),
+        ):
+            date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
+            if function_name == "+":
+                return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
+            else:
+                return default_spark_function_name
         case (DateType() as dt, _) | (_, DateType() as dt):
             date_param_index = 0 if dt == col1.typ else 1
             date_param_name = _get_literal_param_name(