snowpark-connect 0.29.0__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client.py +65 -0
- snowflake/snowpark_connect/column_name_handler.py +6 -0
- snowflake/snowpark_connect/config.py +25 -3
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
- snowflake/snowpark_connect/expression/map_extension.py +277 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +253 -59
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/io_utils.py +61 -4
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
- snowflake/snowpark_connect/relation/map_join.py +8 -0
- snowflake/snowpark_connect/relation/map_row_ops.py +129 -17
- snowflake/snowpark_connect/relation/map_show_string.py +14 -6
- snowflake/snowpark_connect/relation/map_sql.py +39 -5
- snowflake/snowpark_connect/relation/map_stats.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read.py +9 -0
- snowflake/snowpark_connect/relation/read/map_read_csv.py +17 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
- snowflake/snowpark_connect/relation/utils.py +19 -2
- snowflake/snowpark_connect/relation/write/map_write.py +44 -29
- snowflake/snowpark_connect/server.py +11 -3
- snowflake/snowpark_connect/type_mapping.py +75 -3
- snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
- snowflake/snowpark_connect/utils/telemetry.py +105 -23
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/METADATA +1 -1
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/RECORD +41 -37
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@ from pyspark.sql.connect import functions as pyspark_functions
 
 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_proto
 from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
+from snowflake.snowpark.types import DayTimeIntervalType, YearMonthIntervalType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.typed_column import TypedColumn
@@ -33,6 +34,24 @@ from .typer import ExpressionTyper
 
 DECIMAL_RE = re.compile(r"decimal\((\d+), *(\d+)\)")
 
+_INTERVAL_YEARMONTH_PATTERN_RE = re.compile(r"interval (year|month)( to (year|month))?")
+_INTERVAL_DAYTIME_PATTERN_RE = re.compile(
+    r"interval (day|hour|minute|second)( to (day|hour|minute|second))?"
+)
+
+# Interval field mappings using proper constants
+_YEAR_MONTH_FIELD_MAP = {
+    "year": YearMonthIntervalType.YEAR,
+    "month": YearMonthIntervalType.MONTH,
+}
+
+_DAY_TIME_FIELD_MAP = {
+    "day": DayTimeIntervalType.DAY,
+    "hour": DayTimeIntervalType.HOUR,
+    "minute": DayTimeIntervalType.MINUTE,
+    "second": DayTimeIntervalType.SECOND,
+}
+
 _window_specs = ContextVar[dict[str, any]]("_window_specs", default={})
 
 
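Note: the two patterns and field maps above drive all of the new interval handling in this release. A minimal standalone sketch of how a Spark type string resolves to (start, end) field ordinals, with the field constants inlined as plain ints (in the diff they come from snowpark's YearMonthIntervalType/DayTimeIntervalType):

import re

_DAYTIME_RE = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
_DAY_TIME_FIELDS = {"day": 0, "hour": 1, "minute": 2, "second": 3}

def parse_daytime(type_name: str):
    # group(1) is the leading field; group(3) is the optional field after "to"
    m = _DAYTIME_RE.match(type_name)
    if m is None:
        return None
    start = m.group(1)
    end = m.group(3) or start
    return _DAY_TIME_FIELDS[start], _DAY_TIME_FIELDS[end]

print(parse_daytime("interval day to second"))  # (0, 3)
print(parse_daytime("interval hour"))           # (1, 1)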
@@ -388,8 +407,94 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Express
         type_value = types_proto.DataType()
     elif type_name == "binary":
         type_value = bytes(type_value)
-    elif
-
+    elif year_month_match := _INTERVAL_YEARMONTH_PATTERN_RE.match(type_name):
+        # Extract start and end fields for year-month intervals
+        start_field_name = year_month_match.group(1)  # 'year' or 'month'
+        end_field_name = (
+            year_month_match.group(3)
+            if year_month_match.group(3)
+            else start_field_name
+        )
+
+        # Validate field names exist in mapping
+        start_field = _YEAR_MONTH_FIELD_MAP.get(start_field_name)
+        end_field = _YEAR_MONTH_FIELD_MAP.get(end_field_name)
+
+        if start_field is None:
+            raise AnalysisException(
+                f"Invalid year-month interval start field: '{start_field_name}'. Expected 'year' or 'month'."
+            )
+        if end_field is None:
+            raise AnalysisException(
+                f"Invalid year-month interval end field: '{end_field_name}'. Expected 'year' or 'month'."
+            )
+
+        # Validate field ordering (start_field should be <= end_field)
+        if start_field > end_field:
+            raise AnalysisException(
+                f"Invalid year-month interval: start field '{start_field_name}' must come before or equal to end field '{end_field_name}'."
+            )
+
+        # Use extension for year-month intervals to preserve start/end field info
+        literal = expressions_proto.Expression.Literal(
+            year_month_interval=type_value
+        )
+        any_proto = Any()
+        any_proto.Pack(
+            snowflake_proto.ExpExtension(
+                interval_literal=snowflake_proto.IntervalLiteralExpression(
+                    literal=literal,
+                    start_field=start_field,
+                    end_field=end_field,
+                )
+            )
+        )
+        return expressions_proto.Expression(extension=any_proto)
+    elif day_time_match := _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
+        # Extract start and end fields for day-time intervals
+        start_field_name = day_time_match.group(
+            1
+        )  # 'day', 'hour', 'minute', 'second'
+        end_field_name = (
+            day_time_match.group(3)
+            if day_time_match.group(3)
+            else start_field_name
+        )
+
+        # Validate field names exist in mapping
+        start_field = _DAY_TIME_FIELD_MAP.get(start_field_name)
+        end_field = _DAY_TIME_FIELD_MAP.get(end_field_name)
+
+        if start_field is None:
+            raise AnalysisException(
+                f"Invalid day-time interval start field: '{start_field_name}'. Expected 'day', 'hour', 'minute', or 'second'."
+            )
+        if end_field is None:
+            raise AnalysisException(
+                f"Invalid day-time interval end field: '{end_field_name}'. Expected 'day', 'hour', 'minute', or 'second'."
+            )
+
+        # Validate field ordering (start_field should be <= end_field)
+        if start_field > end_field:
+            raise AnalysisException(
+                f"Invalid day-time interval: start field '{start_field_name}' must come before or equal to end field '{end_field_name}'."
+            )
+
+        # Use extension for day-time intervals to preserve start/end field info
+        literal = expressions_proto.Expression.Literal(
+            day_time_interval=type_value
+        )
+        any_proto = Any()
+        any_proto.Pack(
+            snowflake_proto.ExpExtension(
+                interval_literal=snowflake_proto.IntervalLiteralExpression(
+                    literal=literal,
+                    start_field=start_field,
+                    end_field=end_field,
+                )
+            )
+        )
+        return expressions_proto.Expression(extension=any_proto)
     elif m := DECIMAL_RE.fullmatch(type_name):
         type_name = "decimal"
         type_value = expressions_proto.Expression.Literal.Decimal(
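Note: after a pattern match, the branch validates both field names and their ordering before packing the literal and the (start, end) ordinals into a snowflake_proto.ExpExtension via google.protobuf.Any, so the field span survives the trip to the server. The ordering rule, restated standalone with the year-month ordinals inlined (YEAR=0, MONTH=1; the real code raises AnalysisException):

def validate_field_order(start_field: int, end_field: int) -> None:
    # "interval month to year" maps to (1, 0) and must be rejected
    if start_field > end_field:
        raise ValueError("start field must come before or equal to end field")

validate_field_order(0, 1)      # interval year to month: accepted
try:
    validate_field_order(1, 0)  # interval month to year
except ValueError as e:
    print(e)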
@@ -20,7 +20,7 @@ from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -50,6 +50,7 @@ from snowflake.snowpark.types import (
     ByteType,
     DataType,
     DateType,
+    DayTimeIntervalType,
     DecimalType,
     DoubleType,
     FloatType,
@@ -64,6 +65,7 @@ from snowflake.snowpark.types import (
     TimestampTimeZone,
     TimestampType,
     VariantType,
+    YearMonthIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -639,8 +641,17 @@ def map_unresolved_function(
             if isinstance(t, (IntegerType, ShortType, ByteType)):
                 result_type = DateType()
                 result_exp = snowpark_args[0] + snowpark_args[1]
+            elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                result_type = TimestampType()
+                result_exp = (
+                    snowpark_args[date_param_index]
+                    + snowpark_args[t_param_index]
+                )
             elif (
-
+                hasattr(
+                    snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                )
+                and "INTERVAL"
                 == snowpark_typed_args[t_param_index].col._expr1.pretty_name
             ):
                 result_type = TimestampType()
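Note: with the new isinstance branch, adding a properly typed interval to a date resolves to TimestampType directly instead of depending on the expression's pretty_name being "INTERVAL". Illustrative client-side use, assuming a running Spark Connect session named spark:

from pyspark.sql import functions as F

df = spark.sql("SELECT DATE'2024-01-01' AS d")
# date + day-time interval now resolves to a timestamp
df.select((F.col("d") + F.expr("INTERVAL '2' HOUR")).alias("ts")).show()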
@@ -740,6 +751,12 @@ def map_unresolved_function(
                     # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                     result_type = LongType()
                     result_exp = snowpark_args[0] - snowpark_args[1]
+                case (DateType(), DayTimeIntervalType()) | (
+                    DateType(),
+                    YearMonthIntervalType(),
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
                     if (
                         hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
@@ -944,6 +961,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} < {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -958,6 +979,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} <= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<=", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -976,6 +1001,10 @@ def map_unresolved_function(
             )
             result_exp = TypedColumn(left.eqNullSafe(right), lambda: [BooleanType()])
         case "==" | "=":
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "=", snowpark_typed_args, snowpark_arg_names
+            )
             spark_function_name = f"({snowpark_arg_names[0]} = {snowpark_arg_names[1]})"
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
@@ -991,6 +1020,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} > {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
@@ -1005,6 +1038,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} >= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">=", snowpark_typed_args, snowpark_arg_names
+            )
             left, right = _coerce_for_comparison(
                 snowpark_typed_args[0], snowpark_typed_args[1]
             )
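Note: all five comparison operators (<, <=, =, >, >=) now run _check_interval_string_comparison (added near the end of this diff) before coercion, so interval-vs-string comparisons fail fast with a Spark-style error instead of being silently coerced. A sketch of what a client would now see, assuming a session named spark:

from pyspark.sql import functions as F
from pyspark.errors import AnalysisException

df = spark.sql("SELECT INTERVAL '1' DAY AS i")
try:
    df.select(df.i > F.lit("1 day")).collect()
except AnalysisException as e:
    print(e)  # [DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] ...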
@@ -1135,9 +1172,7 @@ def map_unresolved_function(
             # SNOW-1955784: Support accuracy parameter
             # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-            def _pyspark_approx_percentile(
-                column: Column, percentage: float, original_type: DataType
-            ) -> Column:
+            def _pyspark_approx_percentile(column: Column, percentage: float) -> Column:
                 """
                 PySpark-compatible percentile that returns actual values from dataset.
                 - PySpark's approx_percentile returns the "smallest value in the ordered col values
@@ -1154,7 +1189,7 @@ def map_unresolved_function(
                 result = snowpark_fn.function("percentile_disc")(
                     snowpark_fn.lit(percentage)
                 ).within_group(column)
-                return
+                return result
 
             column_type = snowpark_typed_args[0].typ
 
@@ -1165,26 +1200,18 @@ def map_unresolved_function(
                 assert array_func.function_name == "array", array_func
 
                 percentile_results = [
-                    _pyspark_approx_percentile(
-                        snowpark_args[0], unwrap_literal(arg), column_type
-                    )
+                    _pyspark_approx_percentile(snowpark_args[0], unwrap_literal(arg))
                     for arg in array_func.arguments
                 ]
 
                 result_type = ArrayType(element_type=column_type, contains_null=False)
-                result_exp = snowpark_fn.
-
-                    result_type,
-                )
+                result_exp = snowpark_fn.array_construct(*percentile_results)
+                result_exp = _resolve_aggregate_exp(result_exp, result_type)
             else:
                 # Handle single percentile
                 percentage = unwrap_literal(exp.unresolved_function.arguments[1])
-                result_exp =
-
-                    snowpark_args[0], percentage, column_type
-                ),
-                    lambda: [column_type],
-                )
+                result_exp = _pyspark_approx_percentile(snowpark_args[0], percentage)
+                result_exp = _resolve_aggregate_exp(result_exp, column_type)
         case "array":
             if len(snowpark_args) == 0:
                 result_exp = snowpark_fn.cast(
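Note: the approx_percentile rewrite fixes two bugs visible in the old lines (the helper took an unused original_type parameter and ended in a bare return) and routes both the scalar and array forms through _resolve_aggregate_exp. The observable contract of the PERCENTILE_DISC mapping, sketched against a hypothetical session:

df = spark.range(1, 11)  # values 1..10
# PERCENTILE_DISC returns actual elements of the column, matching PySpark's
# "smallest value whose cumulative fraction >= percentage" behavior.
df.selectExpr("approx_percentile(id, 0.5)").show()                # 5
df.selectExpr("approx_percentile(id, array(0.25, 0.75))").show()  # [3, 8]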
@@ -1358,27 +1385,55 @@ def map_unresolved_function(
             result_exp = snowpark_fn.cast(result_exp, array_type)
             result_exp = TypedColumn(result_exp, lambda: snowpark_typed_args[0].types)
         case "array_repeat":
+            elem, count = snowpark_args[0], snowpark_args[1]
+            elem_type = snowpark_typed_args[0].typ
+            result_type = ArrayType(elem_type)
 
-
-                input_types=[VariantType(), LongType()],
-                return_type=ArrayType(),
-            )
-            def _array_repeat(elem, n):
-                if n is None:
-                    return None
-                if n < 0:
-                    return []
-                return [elem] * n
+            fallback_to_udf = True
 
-
-
-
-
-
-
-
-
+            if isinstance(count._expression, Literal):
+                count_value = count._expression.value
+                fallback_to_udf = False
+
+                if count_value is None:
+                    result_exp = snowpark_fn.lit(None).cast(result_type)
+                elif count_value <= 0:
+                    result_exp = snowpark_fn.array_construct().cast(result_type)
+                elif count_value <= 16:
+                    # count_value is small enough to initialize the array directly in memory
+                    elem_variant = snowpark_fn.cast(elem, VariantType())
+                    result_exp = snowpark_fn.array_construct(
+                        *([elem_variant] * count_value)
+                    ).cast(result_type)
+                else:
+                    fallback_to_udf = True
+
+            if fallback_to_udf:
+
+                @cached_udf(
+                    input_types=[VariantType(), LongType()],
+                    return_type=ArrayType(),
+                )
+                def _array_repeat(elem, n):
+                    if n is None:
+                        return None
+                    if n < 0:
+                        return []
+                    return [elem] * n
+
+                elem_variant = snowpark_fn.cast(elem, VariantType())
+
+                result_exp = (
+                    snowpark_fn.when(
+                        count.is_null(), snowpark_fn.lit(None).cast(result_type)
+                    )
+                    .when(count <= 0, snowpark_fn.array_construct().cast(result_type))
+                    .otherwise(
+                        snowpark_fn.cast(
+                            _array_repeat(elem_variant, count), result_type
+                        )
+                    )
+                )
         case "array_size":
             array_type = snowpark_typed_args[0].typ
             if not isinstance(array_type, ArrayType):
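Note: array_repeat now special-cases literal counts so that NULL and non-positive counts are answered without a UDF and counts up to 16 are expanded into one ARRAY_CONSTRUCT call; only non-literal or large counts pay for the cached UDF. The dispatch, restated as plain Python:

def choose_strategy(count_is_literal: bool, count_value=None) -> str:
    if not count_is_literal:
        return "udf"                      # count only known at runtime
    if count_value is None:
        return "null literal"
    if count_value <= 0:
        return "empty array"
    if count_value <= 16:
        return "inline array_construct"   # expanded directly in the SQL
    return "udf"

print(choose_strategy(True, 3))    # inline array_construct
print(choose_strategy(True, 100))  # udf
print(choose_strategy(False))      # udf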
@@ -1578,7 +1633,7 @@ def map_unresolved_function(
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "bit_get" | "getbit":
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
             col, pos = snowpark_args
             if snowflake_compat:
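Note: the extension-behavior parameter is now namespaced; the same rename appears in the hash and md5 cases below. Illustratively, a client opting in would set:

# Config key taken from this diff; how a Spark Connect client would set it.
spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")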
@@ -1885,14 +1940,11 @@ def map_unresolved_function(
             qualifiers = snowpark_args[0].get_qualifiers()
         case "collect_list" | "array_agg":
             # TODO: SNOW-1967177 - Support structured types in array_agg
-            result_exp = snowpark_fn.
-
-                snowpark_typed_args[0].column(to_semi_structure=True)
-                ),
-                ArrayType(snowpark_typed_args[0].typ),
+            result_exp = snowpark_fn.array_agg(
+                snowpark_typed_args[0].column(to_semi_structure=True)
             )
-            result_exp =
-                result_exp,
+            result_exp = _resolve_aggregate_exp(
+                result_exp, ArrayType(snowpark_typed_args[0].typ)
             )
             spark_function_name = f"collect_list({snowpark_arg_names[0]})"
         case "collect_set":
@@ -2379,15 +2431,30 @@ def map_unresolved_function(
                 # If format is NULL, return NULL for all rows
                 result_exp = snowpark_fn.lit(None)
             else:
+                format_lit = snowpark_fn.lit(
+                    map_spark_timestamp_format_expression(
+                        exp.unresolved_function.arguments[1],
+                        snowpark_typed_args[0].typ,
+                    )
+                )
                 result_exp = snowpark_fn.date_format(
                     snowpark_args[0],
-
-                    map_spark_timestamp_format_expression(
-                        exp.unresolved_function.arguments[1],
-                        snowpark_typed_args[0].typ,
-                    )
-                    ),
+                    format_lit,
                 )
+
+                if format_literal == "EEEE":
+                    # TODO: SNOW-2356874, for weekday, Snowflake only supports abbreviated name, e.g. "Fri". Patch spark "EEEE" until
+                    # snowflake supports full weekday name.
+                    result_exp = (
+                        snowpark_fn.when(result_exp == "Mon", "Monday")
+                        .when(result_exp == "Tue", "Tuesday")
+                        .when(result_exp == "Wed", "Wednesday")
+                        .when(result_exp == "Thu", "Thursday")
+                        .when(result_exp == "Fri", "Friday")
+                        .when(result_exp == "Sat", "Saturday")
+                        .when(result_exp == "Sun", "Sunday")
+                        .otherwise(result_exp)
+                    )
             result_exp = TypedColumn(result_exp, lambda: [StringType()])
         case "date_from_unix_date":
             result_exp = snowpark_fn.date_add(
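Note: Snowflake only emits abbreviated day names, so the Spark "EEEE" pattern is patched with a CASE chain until SNOW-2356874 lands. The mapping, restated as a plain dict:

FULL_DAY_NAME = {
    "Mon": "Monday", "Tue": "Tuesday", "Wed": "Wednesday", "Thu": "Thursday",
    "Fri": "Friday", "Sat": "Saturday", "Sun": "Sunday",
}
# date_format(col, "EEEE") now yields e.g. "Friday"; anything unrecognized
# passes through, which is what the trailing .otherwise(result_exp) does.
print(FULL_DAY_NAME.get("Fri", "Fri"))  # Friday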
@@ -2486,6 +2553,14 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "date_part" | "datepart" | "extract":
+            # Check for interval types and throw NotImplementedError
+            if isinstance(
+                snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+            ):
+                raise NotImplementedError(
+                    f"{function_name} with interval types is not supported"
+                )
+
             field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
             if field_lit is None:
@@ -3261,7 +3336,7 @@ def map_unresolved_function(
             # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
             # MapType columns as input should raise an exception as they are not hashable.
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
             # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
             # we want to let it pass through and hash MAP types.
@@ -3768,10 +3843,21 @@ def map_unresolved_function(
                 snowpark_fn.lit(is_outer),
             )
         case "input_file_name":
-            #
-
-
+            # Return the filename metadata column for file-based DataFrames
+            # If METADATA$FILENAME doesn't exist (e.g., for DataFrames created from local data),
+            # return empty string to match Spark's behavior
+            from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                METADATA_FILENAME_COLUMN,
             )
+
+            available_columns = column_mapping.get_snowpark_columns()
+            if METADATA_FILENAME_COLUMN in available_columns:
+                result_exp = snowpark_fn.col(METADATA_FILENAME_COLUMN)
+            else:
+                # Return empty when METADATA$FILENAME column doesn't exist, matching Spark behavior
+                result_exp = snowpark_fn.lit("").cast(StringType())
+            result_type = StringType()
+            spark_function_name = "input_file_name()"
         case "instr":
             result_exp = snowpark_fn.charindex(snowpark_args[1], snowpark_args[0])
             result_type = LongType()
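Note: input_file_name() is now backed by the METADATA$FILENAME column that the new metadata_utils.py attaches to file reads, and degrades to an empty string for DataFrames built from local data. Illustrative usage (the stage path is hypothetical):

from pyspark.sql import functions as F

files_df = spark.read.csv("@my_stage/data/")   # hypothetical stage path
files_df.select(F.input_file_name()).show()    # one source file per row

local_df = spark.createDataFrame([(1,)], ["a"])
local_df.select(F.input_file_name()).show()    # empty strings, as in Spark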
@@ -4731,7 +4817,7 @@ def map_unresolved_function(
             )
         case "md5":
             snowflake_compat = get_boolean_session_config_param(
-                "enable_snowflake_extension_behavior"
+                "snowpark.connect.enable_snowflake_extension_behavior"
             )
 
             # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
@@ -5305,9 +5391,14 @@ def map_unresolved_function(
             result_exp = snowpark_fn.function(function_name)(
                 _check_percentile_percentage(exp.unresolved_function.arguments[1])
             ).within_group(snowpark_args[0])
-            result_exp =
-
+            result_exp = (
+                TypedColumn(
+                    snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+                )
+                if not is_window_enabled()
+                else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
             )
+
             spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
         case "pi":
             spark_function_name = "PI()"
@@ -7526,6 +7617,12 @@ def map_unresolved_function(
             )
             result_type = DateType()
         case "try_add":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_add with interval types is not supported"
+                    )
             result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
             result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
@@ -7579,6 +7676,12 @@ def map_unresolved_function(
                 DoubleType(), cleaned, calculating_avg=True
             )
         case "try_divide":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_divide with interval types is not supported"
+                    )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
@@ -7687,6 +7790,12 @@ def map_unresolved_function(
                 f"Expected either (ArrayType, IntegralType) or (MapType, StringType), got {snowpark_typed_args[0].typ}, {snowpark_typed_args[1].typ}."
             )
         case "try_multiply":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_multiply with interval types is not supported"
+                    )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
@@ -7785,6 +7894,12 @@ def map_unresolved_function(
                 snowpark_typed_args[0].typ, snowpark_args[0]
             )
         case "try_subtract":
+            # Check for interval types and throw NotImplementedError
+            for arg in snowpark_typed_args:
+                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                    raise NotImplementedError(
+                        "try_subtract with interval types is not supported"
+                    )
             result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
             result_exp = _type_with_typer(result_exp)
         case "try_to_number":
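Note: try_add, try_subtract, try_multiply, and try_divide all gained the same pre-check: any interval-typed operand raises NotImplementedError before the arithmetic mapping runs. The guard, extracted standalone with type names as stand-in strings:

def reject_intervals(fn_name: str, arg_type_names: list) -> None:
    for t in arg_type_names:
        if t in ("YearMonthIntervalType", "DayTimeIntervalType"):
            raise NotImplementedError(f"{fn_name} with interval types is not supported")

reject_intervals("try_add", ["LongType", "LongType"])  # passes silently
try:
    reject_intervals("try_add", ["DateType", "DayTimeIntervalType"])
except NotImplementedError as e:
    print(e)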
@@ -9865,6 +9980,70 @@ def _get_add_sub_result_type(
     return result_type, overflow_possible
 
 
+def _get_interval_type_name(
+    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
+) -> str:
+    """Get the formatted interval type name for error messages."""
+    if isinstance(interval_type, YearMonthIntervalType):
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL YEAR"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL MONTH"
+        else:
+            return "INTERVAL YEAR TO MONTH"
+    else:  # DayTimeIntervalType
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL DAY"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL HOUR"
+        elif interval_type.start_field == 2 and interval_type.end_field == 2:
+            return "INTERVAL MINUTE"
+        elif interval_type.start_field == 3 and interval_type.end_field == 3:
+            return "INTERVAL SECOND"
+        else:
+            return "INTERVAL DAY TO SECOND"
+
+
+def _check_interval_string_comparison(
+    operator: str, snowpark_typed_args: List[TypedColumn], snowpark_arg_names: List[str]
+) -> None:
+    """Check for invalid interval-string comparisons and raise AnalysisException if found."""
+    if (
+        isinstance(
+            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+        and isinstance(snowpark_typed_args[1].typ, StringType)
+        or isinstance(snowpark_typed_args[0].typ, StringType)
+        and isinstance(
+            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+    ):
+        # Format interval type name for error message
+        interval_type = (
+            snowpark_typed_args[0].typ
+            if isinstance(
+                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+            )
+            else snowpark_typed_args[1].typ
+        )
+        interval_name = _get_interval_type_name(interval_type)
+
+        left_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[0].typ, StringType)
+            else interval_name
+        )
+        right_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[1].typ, StringType)
+            else interval_name
+        )
+
+        raise AnalysisException(
+            f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} {operator} {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{left_type}" and "{right_type}").;'
+        )
+
+
 def _get_spark_function_name(
     col1: TypedColumn,
     col2: TypedColumn,
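Note: _get_interval_type_name collapses a (start_field, end_field) pair into the SQL type name used in the mismatch message above. A standalone restatement of the year-month branch (ordinals inlined: YEAR=0, MONTH=1; the day-time branch is analogous, with any mixed span collapsing to "INTERVAL DAY TO SECOND"):

def interval_name_year_month(start_field: int, end_field: int) -> str:
    if start_field == 0 and end_field == 0:
        return "INTERVAL YEAR"
    if start_field == 1 and end_field == 1:
        return "INTERVAL MONTH"
    return "INTERVAL YEAR TO MONTH"

print(interval_name_year_month(0, 1))  # INTERVAL YEAR TO MONTH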
@@ -9906,6 +10085,21 @@ def _get_spark_function_name(
                 return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
             else:
                 return f"{operation_func}(cast({date_param_name1} as date), cast({snowpark_arg_names[1]} as double))"
+        case (DateType(), DayTimeIntervalType()) | (
+            DateType(),
+            YearMonthIntervalType(),
+        ):
+            date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
+            return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
+        case (DayTimeIntervalType(), DateType()) | (
+            YearMonthIntervalType(),
+            DateType(),
+        ):
+            date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
+            if function_name == "+":
+                return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
+            else:
+                return default_spark_function_name
         case (DateType() as dt, _) | (_, DateType() as dt):
             date_param_index = 0 if dt == col1.typ else 1
             date_param_name = _get_literal_param_name(