snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client.py +65 -0
- snowflake/snowpark_connect/column_name_handler.py +6 -0
- snowflake/snowpark_connect/config.py +33 -5
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
- snowflake/snowpark_connect/expression/map_extension.py +277 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_join.py +8 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
- snowflake/snowpark_connect/relation/map_show_string.py +14 -6
- snowflake/snowpark_connect/relation/map_sql.py +39 -5
- snowflake/snowpark_connect/relation/map_stats.py +1 -1
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
- snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +67 -4
- snowflake/snowpark_connect/server.py +29 -16
- snowflake/snowpark_connect/type_mapping.py +75 -3
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/telemetry.py +30 -5
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -15,11 +15,12 @@ import tempfile
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -49,6 +50,7 @@ from snowflake.snowpark.types import (
     ByteType,
     DataType,
     DateType,
+    DayTimeIntervalType,
     DecimalType,
     DoubleType,
     FloatType,
@@ -63,6 +65,7 @@ from snowflake.snowpark.types import (
     TimestampTimeZone,
     TimestampType,
     VariantType,
+    YearMonthIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -199,7 +202,7 @@ def _validate_numeric_args(
         case StringType():
             # Cast strings to doubles following Spark
             # https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L204
-            modified_args[i] =
+            modified_args[i] = snowpark_fn.try_cast(snowpark_args[i], DoubleType())
         case _:
             raise TypeError(
                 f"Data type mismatch: {function_name} requires numeric types, but got {typed_args[0].typ} and {typed_args[1].typ}."
@@ -519,7 +522,7 @@ def map_unresolved_function(
                     DecimalType() as t,
                 ):
                     p1, s1 = _get_type_precision(t)
-                    result_type = _get_decimal_multiplication_result_type(
+                    result_type, _ = _get_decimal_multiplication_result_type(
                         p1, s1, p1, s1
                     )
                     result_exp = snowpark_fn.lit(None)
@@ -528,11 +531,17 @@ def map_unresolved_function(
                 ):
                     p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
                     p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
-
-
-
-
-
+                    (
+                        result_type,
+                        overflow_possible,
+                    ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x * y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
                     )
                 case (NullType(), NullType()):
                     result_type = DoubleType()
@@ -617,7 +626,7 @@ def map_unresolved_function(
             )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), _) | (_, NullType()):
-                    result_type = _get_add_sub_result_type(
+                    result_type, _ = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
@@ -632,8 +641,17 @@ def map_unresolved_function(
                     if isinstance(t, (IntegerType, ShortType, ByteType)):
                         result_type = DateType()
                         result_exp = snowpark_args[0] + snowpark_args[1]
+                    elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[date_param_index]
+                            + snowpark_args[t_param_index]
+                        )
                     elif (
-
+                        hasattr(
+                            snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                        )
+                        and "INTERVAL"
                         == snowpark_typed_args[t_param_index].col._expr1.pretty_name
                     ):
                         result_type = TimestampType()
@@ -693,14 +711,21 @@ def map_unresolved_function(
                         f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
                 case _:
-                    result_type = _get_add_sub_result_type(
+                    result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
                     )
-
-
-
+
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x + y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
+                    )
+
         case "-":
             spark_function_name = _get_spark_function_name(
                 snowpark_typed_args[0],
@@ -715,7 +740,7 @@ def map_unresolved_function(
                     result_type = LongType()
                     result_exp = snowpark_fn.lit(None).cast(result_type)
                 case (NullType(), _) | (_, NullType()):
-                    result_type = _get_add_sub_result_type(
+                    result_type, _ = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
@@ -726,6 +751,12 @@ def map_unresolved_function(
                     # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                     result_type = LongType()
                     result_exp = snowpark_args[0] - snowpark_args[1]
+                case (DateType(), DayTimeIntervalType()) | (
+                    DateType(),
+                    YearMonthIntervalType(),
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
                     if (
                         hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
@@ -806,14 +837,20 @@ def map_unresolved_function(
                         f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
                 case _:
-                    result_type = _get_add_sub_result_type(
+                    result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
                     )
-                    result_exp =
-
-
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x - y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
+                    )
+
         case "/":
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (DecimalType() as t1, NullType()):
@@ -825,15 +862,17 @@ def map_unresolved_function(
                 ):
                     p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
                     p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
-                    result_type,
+                    result_type, overflow_possible = _get_decimal_division_result_type(
                         p1, s1, p2, s2
                     )
-
+
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: _divnull(x, y),
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
                         result_type,
-                        t,
-                        overflow_detected,
-                        snowpark_args,
-                        spark_function_name,
                     )
                 case (NullType(), NullType()):
                     result_type = DoubleType()
@@ -922,6 +961,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} < {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<", snowpark_typed_args, snowpark_arg_names
+            )
            left, right = _coerce_for_comparison(
                snowpark_typed_args[0], snowpark_typed_args[1]
            )
@@ -936,6 +979,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} <= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                "<=", snowpark_typed_args, snowpark_arg_names
+            )
            left, right = _coerce_for_comparison(
                snowpark_typed_args[0], snowpark_typed_args[1]
            )
@@ -954,6 +1001,10 @@ def map_unresolved_function(
            )
            result_exp = TypedColumn(left.eqNullSafe(right), lambda: [BooleanType()])
        case "==" | "=":
+           # Check for interval-string comparisons
+           _check_interval_string_comparison(
+               "=", snowpark_typed_args, snowpark_arg_names
+           )
            spark_function_name = f"({snowpark_arg_names[0]} = {snowpark_arg_names[1]})"
            left, right = _coerce_for_comparison(
                snowpark_typed_args[0], snowpark_typed_args[1]
@@ -969,6 +1020,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} > {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">", snowpark_typed_args, snowpark_arg_names
+            )
            left, right = _coerce_for_comparison(
                snowpark_typed_args[0], snowpark_typed_args[1]
            )
@@ -983,6 +1038,10 @@ def map_unresolved_function(
                 raise AnalysisException(
                     f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{snowpark_arg_names[0]} >= {snowpark_arg_names[1]}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").;'
                 )
+            # Check for interval-string comparisons
+            _check_interval_string_comparison(
+                ">=", snowpark_typed_args, snowpark_arg_names
+            )
            left, right = _coerce_for_comparison(
                snowpark_typed_args[0], snowpark_typed_args[1]
            )
@@ -1113,9 +1172,7 @@ def map_unresolved_function(
            # SNOW-1955784: Support accuracy parameter
            # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-           def _pyspark_approx_percentile(
-               column: Column, percentage: float, original_type: DataType
-           ) -> Column:
+           def _pyspark_approx_percentile(column: Column, percentage: float) -> Column:
                """
                PySpark-compatible percentile that returns actual values from dataset.
                - PySpark's approx_percentile returns the "smallest value in the ordered col values
@@ -1132,7 +1189,7 @@ def map_unresolved_function(
                result = snowpark_fn.function("percentile_disc")(
                    snowpark_fn.lit(percentage)
                ).within_group(column)
-               return
+               return result
 
            column_type = snowpark_typed_args[0].typ
 
@@ -1143,26 +1200,18 @@ def map_unresolved_function(
                assert array_func.function_name == "array", array_func
 
                percentile_results = [
-                   _pyspark_approx_percentile(
-                       snowpark_args[0], unwrap_literal(arg), column_type
-                   )
+                   _pyspark_approx_percentile(snowpark_args[0], unwrap_literal(arg))
                    for arg in array_func.arguments
                ]
 
                result_type = ArrayType(element_type=column_type, contains_null=False)
-               result_exp = snowpark_fn.
-
-                   result_type,
-               )
+               result_exp = snowpark_fn.array_construct(*percentile_results)
+               result_exp = _resolve_aggregate_exp(result_exp, result_type)
            else:
                # Handle single percentile
                percentage = unwrap_literal(exp.unresolved_function.arguments[1])
-               result_exp =
-
-                   snowpark_args[0], percentage, column_type
-               ),
-               lambda: [column_type],
-               )
+               result_exp = _pyspark_approx_percentile(snowpark_args[0], percentage)
+               result_exp = _resolve_aggregate_exp(result_exp, column_type)
        case "array":
            if len(snowpark_args) == 0:
                result_exp = snowpark_fn.cast(
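The docstring above describes the discrete-percentile rule that PERCENTILE_DISC implements: return the smallest value in the ordered column whose position covers the requested fraction, so the result is always an actual element of the data. A minimal pure-Python illustration of that rule (not code from the package):

import math

def percentile_disc(values, percentage):
    """Return the smallest element whose ordered position covers `percentage`."""
    ordered = sorted(values)
    # Smallest index whose cumulative fraction is >= percentage.
    idx = max(0, math.ceil(percentage * len(ordered)) - 1)
    return ordered[idx]

print(percentile_disc([1, 2, 3, 4], 0.5))   # 2 -- an actual value from the data
print(percentile_disc([1, 2, 3, 4], 0.75))  # 3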
@@ -1336,27 +1385,55 @@ def map_unresolved_function(
            result_exp = snowpark_fn.cast(result_exp, array_type)
            result_exp = TypedColumn(result_exp, lambda: snowpark_typed_args[0].types)
        case "array_repeat":
+           elem, count = snowpark_args[0], snowpark_args[1]
+           elem_type = snowpark_typed_args[0].typ
+           result_type = ArrayType(elem_type)
 
-
-               input_types=[VariantType(), LongType()],
-               return_type=ArrayType(),
-           )
-           def _array_repeat(elem, n):
-               if n is None:
-                   return None
-               if n < 0:
-                   return []
-               return [elem] * n
+           fallback_to_udf = True
 
-
-
-
-
-
-
-
-
-
+           if isinstance(count._expression, Literal):
+               count_value = count._expression.value
+               fallback_to_udf = False
+
+               if count_value is None:
+                   result_exp = snowpark_fn.lit(None).cast(result_type)
+               elif count_value <= 0:
+                   result_exp = snowpark_fn.array_construct().cast(result_type)
+               elif count_value <= 16:
+                   # count_value is small enough to initialize the array directly in memory
+                   elem_variant = snowpark_fn.cast(elem, VariantType())
+                   result_exp = snowpark_fn.array_construct(
+                       *([elem_variant] * count_value)
+                   ).cast(result_type)
+               else:
+                   fallback_to_udf = True
+
+           if fallback_to_udf:
+
+               @cached_udf(
+                   input_types=[VariantType(), LongType()],
+                   return_type=ArrayType(),
+               )
+               def _array_repeat(elem, n):
+                   if n is None:
+                       return None
+                   if n < 0:
+                       return []
+                   return [elem] * n
+
+               elem_variant = snowpark_fn.cast(elem, VariantType())
+
+               result_exp = (
+                   snowpark_fn.when(
+                       count.is_null(), snowpark_fn.lit(None).cast(result_type)
+                   )
+                   .when(count <= 0, snowpark_fn.array_construct().cast(result_type))
+                   .otherwise(
+                       snowpark_fn.cast(
+                           _array_repeat(elem_variant, count), result_type
+                       )
+                   )
+               )
        case "array_size":
            array_type = snowpark_typed_args[0].typ
            if not isinstance(array_type, ArrayType):
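The hunk above replaces the unconditional UDF for array_repeat with a fast path for literal counts. A compact sketch of the dispatch rule it implements (plain Python, for illustration only):

def choose_array_repeat_strategy(count_is_literal: bool, count_value):
    if not count_is_literal:
        return "udf"                      # count only known at runtime
    if count_value is None:
        return "null_literal"             # NULL count -> NULL array
    if count_value <= 0:
        return "empty_array"
    if count_value <= 16:
        return "inline_array_construct"   # small enough to expand directly
    return "udf"                          # large literal counts still use the UDF

assert choose_array_repeat_strategy(True, 3) == "inline_array_construct"
assert choose_array_repeat_strategy(True, 100) == "udf"
assert choose_array_repeat_strategy(False, None) == "udf"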
@@ -1556,7 +1633,7 @@ def map_unresolved_function(
            result_exp = TypedColumn(result_exp, lambda: [LongType()])
        case "bit_get" | "getbit":
            snowflake_compat = get_boolean_session_config_param(
-               "enable_snowflake_extension_behavior"
+               "snowpark.connect.enable_snowflake_extension_behavior"
            )
            col, pos = snowpark_args
            if snowflake_compat:
@@ -1863,14 +1940,11 @@ def map_unresolved_function(
            qualifiers = snowpark_args[0].get_qualifiers()
        case "collect_list" | "array_agg":
            # TODO: SNOW-1967177 - Support structured types in array_agg
-           result_exp = snowpark_fn.
-
-               snowpark_typed_args[0].column(to_semi_structure=True)
-               ),
-               ArrayType(snowpark_typed_args[0].typ),
+           result_exp = snowpark_fn.array_agg(
+               snowpark_typed_args[0].column(to_semi_structure=True)
            )
-           result_exp =
-               result_exp,
+           result_exp = _resolve_aggregate_exp(
+               result_exp, ArrayType(snowpark_typed_args[0].typ)
            )
            spark_function_name = f"collect_list({snowpark_arg_names[0]})"
        case "collect_set":
@@ -2357,15 +2431,30 @@ def map_unresolved_function(
                # If format is NULL, return NULL for all rows
                result_exp = snowpark_fn.lit(None)
            else:
+               format_lit = snowpark_fn.lit(
+                   map_spark_timestamp_format_expression(
+                       exp.unresolved_function.arguments[1],
+                       snowpark_typed_args[0].typ,
+                   )
+               )
                result_exp = snowpark_fn.date_format(
                    snowpark_args[0],
-
-                   map_spark_timestamp_format_expression(
-                       exp.unresolved_function.arguments[1],
-                       snowpark_typed_args[0].typ,
-                   )
-               ),
+                   format_lit,
                )
+
+               if format_literal == "EEEE":
+                   # TODO: SNOW-2356874, for weekday, Snowflake only supports abbreviated name, e.g. "Fri". Patch spark "EEEE" until
+                   # snowflake supports full weekday name.
+                   result_exp = (
+                       snowpark_fn.when(result_exp == "Mon", "Monday")
+                       .when(result_exp == "Tue", "Tuesday")
+                       .when(result_exp == "Wed", "Wednesday")
+                       .when(result_exp == "Thu", "Thursday")
+                       .when(result_exp == "Fri", "Friday")
+                       .when(result_exp == "Sat", "Saturday")
+                       .when(result_exp == "Sun", "Sunday")
+                       .otherwise(result_exp)
+                   )
            result_exp = TypedColumn(result_exp, lambda: [StringType()])
        case "date_from_unix_date":
            result_exp = snowpark_fn.date_add(
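The "EEEE" patch above compensates for Snowflake emitting only abbreviated weekday names by mapping them to the full names Spark produces, passing anything else through unchanged. The same mapping in plain Python (illustrative only):

_FULL_WEEKDAY = {
    "Mon": "Monday", "Tue": "Tuesday", "Wed": "Wednesday", "Thu": "Thursday",
    "Fri": "Friday", "Sat": "Saturday", "Sun": "Sunday",
}

def patch_weekday(formatted: str) -> str:
    # Values that are not abbreviated weekday names pass through unchanged,
    # mirroring the .otherwise(result_exp) branch in the diff.
    return _FULL_WEEKDAY.get(formatted, formatted)

assert patch_weekday("Fri") == "Friday"
assert patch_weekday("2024-01-01") == "2024-01-01"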
@@ -2464,6 +2553,14 @@ def map_unresolved_function(
            )
            result_type = LongType()
        case "date_part" | "datepart" | "extract":
+           # Check for interval types and throw NotImplementedError
+           if isinstance(
+               snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+           ):
+               raise NotImplementedError(
+                   f"{function_name} with interval types is not supported"
+               )
+
            field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
            if field_lit is None:
@@ -3239,7 +3336,7 @@ def map_unresolved_function(
            # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
            # MapType columns as input should raise an exception as they are not hashable.
            snowflake_compat = get_boolean_session_config_param(
-               "enable_snowflake_extension_behavior"
+               "snowpark.connect.enable_snowflake_extension_behavior"
            )
            # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
            # we want to let it pass through and hash MAP types.
@@ -3746,10 +3843,21 @@ def map_unresolved_function(
                snowpark_fn.lit(is_outer),
            )
        case "input_file_name":
-           #
-
-
+           # Return the filename metadata column for file-based DataFrames
+           # If METADATA$FILENAME doesn't exist (e.g., for DataFrames created from local data),
+           # return empty string to match Spark's behavior
+           from snowflake.snowpark_connect.relation.read.metadata_utils import (
+               METADATA_FILENAME_COLUMN,
            )
+
+           available_columns = column_mapping.get_snowpark_columns()
+           if METADATA_FILENAME_COLUMN in available_columns:
+               result_exp = snowpark_fn.col(METADATA_FILENAME_COLUMN)
+           else:
+               # Return empty when METADATA$FILENAME column doesn't exist, matching Spark behavior
+               result_exp = snowpark_fn.lit("").cast(StringType())
+           result_type = StringType()
+           spark_function_name = "input_file_name()"
        case "instr":
            result_exp = snowpark_fn.charindex(snowpark_args[1], snowpark_args[0])
            result_type = LongType()
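A small sketch of the fallback rule introduced above: input_file_name() resolves to the METADATA$FILENAME column when the reader attached it, and to an empty string otherwise, matching Spark for DataFrames not backed by files. The returned expression strings here are illustrative placeholders, not Snowpark API calls:

METADATA_FILENAME_COLUMN = "METADATA$FILENAME"

def resolve_input_file_name(available_columns: list) -> str:
    # File-backed DataFrames carry the metadata column; others get ''.
    if METADATA_FILENAME_COLUMN in available_columns:
        return f"col({METADATA_FILENAME_COLUMN})"
    return "lit('')"

assert resolve_input_file_name(["A", "METADATA$FILENAME"]) == "col(METADATA$FILENAME)"
assert resolve_input_file_name(["A"]) == "lit('')"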
@@ -4709,7 +4817,7 @@ def map_unresolved_function(
            )
        case "md5":
            snowflake_compat = get_boolean_session_config_param(
-               "enable_snowflake_extension_behavior"
+               "snowpark.connect.enable_snowflake_extension_behavior"
            )
 
            # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
@@ -5283,9 +5391,14 @@ def map_unresolved_function(
            result_exp = snowpark_fn.function(function_name)(
                _check_percentile_percentage(exp.unresolved_function.arguments[1])
            ).within_group(snowpark_args[0])
-           result_exp =
-
+           result_exp = (
+               TypedColumn(
+                   snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
+               )
+               if not is_window_enabled()
+               else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
            )
+
            spark_function_name = f"{function_name}({unwrap_literal(exp.unresolved_function.arguments[1])}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]})"
        case "pi":
            spark_function_name = "PI()"
@@ -7504,6 +7617,12 @@ def map_unresolved_function(
            )
            result_type = DateType()
        case "try_add":
+           # Check for interval types and throw NotImplementedError
+           for arg in snowpark_typed_args:
+               if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                   raise NotImplementedError(
+                       "try_add with interval types is not supported"
+                   )
            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 0)
            result_exp = _type_with_typer(result_exp)
        case "try_aes_decrypt":
@@ -7557,6 +7676,12 @@ def map_unresolved_function(
                DoubleType(), cleaned, calculating_avg=True
            )
        case "try_divide":
+           # Check for interval types and throw NotImplementedError
+           for arg in snowpark_typed_args:
+               if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                   raise NotImplementedError(
+                       "try_divide with interval types is not supported"
+                   )
            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                case (NullType(), t) | (t, NullType()):
                    result_exp = snowpark_fn.lit(None)
@@ -7580,63 +7705,20 @@ def map_unresolved_function(
                    )
                    | (DecimalType(), DecimalType())
                ):
-
-
-
-
-
-                       snowpark_typed_args[0].typ.scale,
-                       snowpark_typed_args[0].typ.precision,
-                   )
-                   s2, p2 = (
-                       snowpark_typed_args[1].typ.scale,
-                       snowpark_typed_args[1].typ.precision,
-                   )
-                   # The scale and precision formula that Spark follows for DecimalType
-                   # arithmetic operations can be found in the following Spark source
-                   # code file:
-                   # https://github.com/apache/spark/blob/a584cc48ef63fefb2e035349c8684250f8b936c4/docs/sql-ref-ansi-compliance.md
-                   new_scale = max(6, s1 + p2 + 1)
-                   new_precision = p1 - s1 + s2 + new_scale
-
-                   elif isinstance(snowpark_typed_args[0].typ, DecimalType):
-                       s1, p1 = (
-                           snowpark_typed_args[0].typ.scale,
-                           snowpark_typed_args[0].typ.precision,
-                       )
-                       # INT is treated as Decimal(10, 0)
-                       new_scale = max(6, s1 + 11)
-                       new_precision = p1 - s1 + new_scale
-
-                   else:  # right is DecimalType
-                       s2, p2 = (
-                           snowpark_typed_args[1].typ.scale,
-                           snowpark_typed_args[1].typ.precision,
-                       )
-                       # INT is treated as Decimal(10, 0)
-                       new_scale = max(6, 11 + p2)
-                       new_precision = (
-                           10 - 0 + s2 + new_scale
-                       )  # INT has precision 10, scale 0
-
-                   # apply precision cap
-                   if new_precision > 38:
-                       new_scale -= new_precision - 38
-                       new_precision = 38
-                       new_scale = max(new_scale, 6)
-
-                   left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                   right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                   quotient = snowpark_fn.when(
-                       snowpark_args[1] == 0, snowpark_fn.lit(None)
-                   ).otherwise(left_double / right_double)
-                   quotient = snowpark_fn.cast(quotient, StringType())
+                   p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                   p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                   result_type, overflow_possible = _get_decimal_division_result_type(
+                       p1, s1, p2, s2
+                   )
 
-                   result_exp =
-
+                   result_exp = _arithmetic_operation(
+                       snowpark_typed_args[0],
+                       snowpark_typed_args[1],
+                       lambda x, y: _divnull(x, y),
+                       overflow_possible,
+                       False,
+                       result_type,
                    )
-                   result_type = DecimalType(new_precision, new_scale)
                case (_NumericType(), _NumericType()):
                    result_exp = snowpark_fn.when(
                        snowpark_args[1] == 0, snowpark_fn.lit(None)
@@ -7708,6 +7790,12 @@ def map_unresolved_function(
                f"Expected either (ArrayType, IntegralType) or (MapType, StringType), got {snowpark_typed_args[0].typ}, {snowpark_typed_args[1].typ}."
            )
        case "try_multiply":
+           # Check for interval types and throw NotImplementedError
+           for arg in snowpark_typed_args:
+               if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                   raise NotImplementedError(
+                       "try_multiply with interval types is not supported"
+                   )
            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                case (NullType(), t) | (t, NullType()):
                    result_exp = snowpark_fn.lit(None)
@@ -7749,42 +7837,21 @@ def map_unresolved_function(
                    )
                    | (DecimalType(), DecimalType())
                ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                   new_precision = snowpark_typed_args[0].typ.precision + 11
-                   new_scale = snowpark_typed_args[0].typ.scale
-                   else:
-                       new_precision = snowpark_typed_args[1].typ.precision + 11
-                       new_scale = snowpark_typed_args[1].typ.scale
-
-                   # truncating down appropriately
-                   if new_precision > 38:
-                       new_precision = 38
-                   if new_scale > new_precision:
-                       new_scale = new_precision
-
-                   left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                   right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                   product = left_double * right_double
-
-                   product = snowpark_fn.cast(product, StringType())
-                   result_exp = _try_cast_helper(
-                       product, DecimalType(new_precision, new_scale)
+                   p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                   p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                   (
+                       result_type,
+                       overflow_possible,
+                   ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+
+                   result_exp = _arithmetic_operation(
+                       snowpark_typed_args[0],
+                       snowpark_typed_args[1],
+                       lambda x, y: x * y,
+                       overflow_possible,
+                       False,
+                       result_type,
                    )
-                   result_type = DecimalType(new_precision, new_scale)
                case (_NumericType(), _NumericType()):
                    result_exp = snowpark_args[0] * snowpark_args[1]
                    result_exp = _type_with_typer(result_exp)
@@ -7827,6 +7894,12 @@ def map_unresolved_function(
                snowpark_typed_args[0].typ, snowpark_args[0]
            )
        case "try_subtract":
+           # Check for interval types and throw NotImplementedError
+           for arg in snowpark_typed_args:
+               if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
+                   raise NotImplementedError(
+                       "try_subtract with interval types is not supported"
+                   )
            result_exp = _try_arithmetic_helper(snowpark_typed_args, snowpark_args, 1)
            result_exp = _type_with_typer(result_exp)
        case "try_to_number":
@@ -8391,20 +8464,10 @@ def map_unresolved_function(
    return spark_col_names, typed_col
 
 
-def _cast_helper(column: Column, to: DataType) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        column_mediator = (
-            snowpark_fn.cast(column, StringType())
-            if isinstance(to, DecimalType)
-            else column
-        )
-        return snowpark_fn.cast(column_mediator, to)
-    else:
-        return _try_cast_helper(column, to)
-
-
 def _try_cast_helper(column: Column, to: DataType) -> Column:
     """
+    DEPRECATED because of performance issues
+
     Attempts to cast a given column to a specified data type using the same behaviour as Spark.
 
     Args:
@@ -9600,71 +9663,109 @@ def _decimal_add_sub_result_type_helper(p1, s1, p2, s2):
    return result_precision, min_scale, return_type_precision, return_type_scale
 
 
-def
-    result_type: DecimalType | DataType,
-    other_type: DataType,
-    snowpark_args: list[Column],
-) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        result_exp = snowpark_args[0] * snowpark_args[1]
-    else:
-        if isinstance(other_type, _IntegralType):
-            result_exp = snowpark_args[0].cast(result_type) * snowpark_args[1].cast(
-                result_type
-            )
-        else:
-            result_exp = snowpark_args[0].cast(DoubleType()) * snowpark_args[1].cast(
-                DoubleType()
-            )
-        result_exp = _try_cast_helper(result_exp, result_type)
-    return result_exp
-
-
-def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> DecimalType:
+def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
     result_precision = p1 + p2 + 1
     result_scale = s1 + s2
+    overflow_possible = False
     if result_precision > 38:
+        overflow_possible = True
         if result_scale > 6:
            overflow = result_precision - 38
            result_scale = max(6, result_scale - overflow)
        result_precision = 38
-    return DecimalType(result_precision, result_scale)
+    return DecimalType(result_precision, result_scale), overflow_possible
 
 
-def
-
-
-
-
-
+def _arithmetic_operation(
+    arg1: TypedColumn,
+    arg2: TypedColumn,
+    op: Callable[[Column, Column], Column],
+    overflow_possible: bool,
+    should_raise_on_overflow: bool,
+    target_type: DecimalType,
 ) -> Column:
-
-
-
-
-
-
-
+    def _cast_arg(tc: TypedColumn) -> Column:
+        _, s = _get_type_precision(tc.typ)
+        typ = (
+            DoubleType()
+            if s > 0
+            or (
+                isinstance(tc.typ, _FractionalType)
+                and not isinstance(tc.typ, DecimalType)
+            )
+            else LongType()
+        )
+        return tc.col.cast(typ)
+
+    op_for_overflow_check = op(arg1.col.cast(DoubleType()), arg2.col.cast(DoubleType()))
+    safe_op = op(_cast_arg(arg1), _cast_arg(arg2))
+
+    if overflow_possible:
+        return _cast_arithmetic_operation_result(
+            op_for_overflow_check, safe_op, target_type, should_raise_on_overflow
        )
    else:
-
-
-
-
-
+        return op(arg1.col, arg2.col).cast(target_type)
+
+
+def _cast_arithmetic_operation_result(
+    overflow_check_expr: Column,
+    result_expr: Column,
+    target_type: DecimalType,
+    should_raise_on_overflow: bool,
+) -> Column:
+    """
+    Casts an arithmetic operation result to the target decimal type with overflow detection.
+    This function uses a dual-expression approach for robust overflow handling:
+    Args:
+        overflow_check_expr: Arithmetic expression using DoubleType operands for overflow detection.
+                             This expression is used ONLY for boundary checking against the target
+                             decimal's min/max values. DoubleType preserves the magnitude of large
+                             intermediate results that might overflow in decimal arithmetic.
+        result_expr: Arithmetic expression using safer operand types (LongType for integers,
+                     DoubleType for fractionals) for the actual result computation.
+        target_type: Target DecimalType to cast the result to.
+        should_raise_on_overflow: If True raises ArithmeticException on overflow, if False, returns NULL on overflow.
+    """
+
+    def create_overflow_handler(min_val, max_val, type_name: str):
+        if should_raise_on_overflow:
+            raise_error = _raise_error_helper(target_type, ArithmeticException)
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                raise_error(
+                    snowpark_fn.lit(
+                        f'[NUMERIC_VALUE_OUT_OF_RANGE] Value cannot be represented as {type_name}. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+                    )
+                ),
+            ).otherwise(result_expr.cast(target_type))
+        else:
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                snowpark_fn.lit(None),
+            ).otherwise(result_expr.cast(target_type))
+
+    precision = target_type.precision
+    scale = target_type.scale
+
+    max_val = (10**precision - 1) / (10**scale)
+    min_val = -max_val
+
+    return create_overflow_handler(min_val, max_val, f"DECIMAL({precision},{scale})")
 
 
 def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
-
+    overflow_possible = False
     result_scale = max(6, s1 + p2 + 1)
     result_precision = p1 - s1 + s2 + result_scale
     if result_precision > 38:
-
-        overflow_detected = True
+        overflow_possible = True
        overflow = result_precision - 38
        result_scale = max(6, result_scale - overflow)
        result_precision = 38
-    return DecimalType(result_precision, result_scale),
+    return DecimalType(result_precision, result_scale), overflow_possible
 
 
 def _try_arithmetic_helper(
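The helpers above encode Spark's decimal result-type rules: multiplication yields precision p1 + p2 + 1 and scale s1 + s2; division yields scale max(6, s1 + p2 + 1) and precision p1 - s1 + s2 + scale; both cap precision at 38 and flag possible overflow when the cap bites. A worked pure-Python version of those formulas (illustrative, mirrors the diff but is not the package's code):

def decimal_multiplication_result(p1, s1, p2, s2):
    precision, scale = p1 + p2 + 1, s1 + s2
    overflow_possible = precision > 38
    if overflow_possible and scale > 6:
        scale = max(6, scale - (precision - 38))
    precision = min(precision, 38)
    return (precision, scale), overflow_possible

def decimal_division_result(p1, s1, p2, s2):
    scale = max(6, s1 + p2 + 1)
    precision = p1 - s1 + s2 + scale
    overflow_possible = precision > 38
    if overflow_possible:
        scale = max(6, scale - (precision - 38))
        precision = 38
    return (precision, scale), overflow_possible

# DECIMAL(38,10) * DECIMAL(10,2): raw precision 49 is capped to 38 and flagged.
assert decimal_multiplication_result(38, 10, 10, 2) == ((38, 6), True)
# DECIMAL(10,2) / DECIMAL(10,2): scale max(6, 2+10+1) = 13, precision 10-2+2+13 = 23.
assert decimal_division_result(10, 2, 10, 2) == ((23, 13), False)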
@@ -9778,46 +9879,20 @@ def _try_arithmetic_helper(
        DecimalType(),
        DecimalType(),
    ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            # Both decimal types
-            if operation_type == 1 and s1 == s2:  # subtraction with matching scales
-                new_scale = s1
-                max_integral_digits = max(p1 - s1, p2 - s2)
-                new_precision = max_integral_digits + new_scale
-            else:
-                new_scale = max(s1, s2)
-                max_integral_digits = max(p1 - s1, p2 - s2)
-                new_precision = max_integral_digits + new_scale + 1
-
-        # Overflow check
-        if new_precision > 38:
-            if global_config.spark_sql_ansi_enabled:
-                raise ArithmeticException(
-                    f'[NUMERIC_VALUE_OUT_OF_RANGE] Precision {new_precision} exceeds maximum allowed precision of 38. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
-                )
-            return snowpark_fn.lit(None)
-
-        left_operand, right_operand = snowpark_args[0], snowpark_args[1]
-
-        result = (
-            left_operand + right_operand
-            if operation_type == 0
-            else left_operand - right_operand
+        result_type, overflow_possible = _get_add_sub_result_type(
+            typed_args[0].typ,
+            typed_args[1].typ,
+            "try_add" if operation_type == 0 else "try_subtract",
+        )
+
+        return _arithmetic_operation(
+            typed_args[0],
+            typed_args[1],
+            lambda x, y: x + y if operation_type == 0 else x - y,
+            overflow_possible,
+            False,
+            result_type,
        )
-        return snowpark_fn.cast(result, DecimalType(new_precision, new_scale))
 
    # If either of the inputs is floating point, we can just let it go through to Snowflake, where overflow
    # matches Spark and goes to inf.
@@ -9863,7 +9938,8 @@ def _get_add_sub_result_type(
    type1: DataType,
    type2: DataType,
    spark_function_name: str,
-) -> DataType:
+) -> tuple[DataType, bool]:
+    overflow_possible = False
    result_type = _find_common_type([type1, type2])
    match result_type:
        case DecimalType():
@@ -9872,6 +9948,7 @@ def _get_add_sub_result_type(
            result_scale = max(s1, s2)
            result_precision = max(p1 - s1, p2 - s2) + result_scale + 1
            if result_precision > 38:
+                overflow_possible = True
                if result_scale > 6:
                    overflow = result_precision - 38
                    result_scale = max(6, result_scale - overflow)
@@ -9900,7 +9977,71 @@ def _get_add_sub_result_type(
            raise AnalysisException(
                f'[DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "BOOLEAN".',
            )
-    return result_type
+    return result_type, overflow_possible
+
+
+def _get_interval_type_name(
+    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
+) -> str:
+    """Get the formatted interval type name for error messages."""
+    if isinstance(interval_type, YearMonthIntervalType):
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL YEAR"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL MONTH"
+        else:
+            return "INTERVAL YEAR TO MONTH"
+    else:  # DayTimeIntervalType
+        if interval_type.start_field == 0 and interval_type.end_field == 0:
+            return "INTERVAL DAY"
+        elif interval_type.start_field == 1 and interval_type.end_field == 1:
+            return "INTERVAL HOUR"
+        elif interval_type.start_field == 2 and interval_type.end_field == 2:
+            return "INTERVAL MINUTE"
+        elif interval_type.start_field == 3 and interval_type.end_field == 3:
+            return "INTERVAL SECOND"
+        else:
+            return "INTERVAL DAY TO SECOND"
+
+
+def _check_interval_string_comparison(
+    operator: str, snowpark_typed_args: List[TypedColumn], snowpark_arg_names: List[str]
+) -> None:
+    """Check for invalid interval-string comparisons and raise AnalysisException if found."""
+    if (
+        isinstance(
+            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+        and isinstance(snowpark_typed_args[1].typ, StringType)
+        or isinstance(snowpark_typed_args[0].typ, StringType)
+        and isinstance(
+            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
+        )
+    ):
+        # Format interval type name for error message
+        interval_type = (
+            snowpark_typed_args[0].typ
+            if isinstance(
+                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
+            )
+            else snowpark_typed_args[1].typ
+        )
+        interval_name = _get_interval_type_name(interval_type)
+
+        left_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[0].typ, StringType)
+            else interval_name
+        )
+        right_type = (
+            "STRING"
+            if isinstance(snowpark_typed_args[1].typ, StringType)
+            else interval_name
+        )
+
+        raise AnalysisException(
+            f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} {operator} {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{left_type}" and "{right_type}").;'
+        )
 
 
 def _get_spark_function_name(
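_get_interval_type_name above maps start/end field indices to SQL interval names (0 = YEAR or DAY, 1 = MONTH or HOUR, and so on), collapsing to a single-unit name when the two fields coincide. A trimmed illustration of the year-month case (plain Python, for reference only):

def year_month_interval_name(start_field: int, end_field: int) -> str:
    if start_field == end_field == 0:
        return "INTERVAL YEAR"
    if start_field == end_field == 1:
        return "INTERVAL MONTH"
    return "INTERVAL YEAR TO MONTH"

assert year_month_interval_name(0, 0) == "INTERVAL YEAR"
assert year_month_interval_name(0, 1) == "INTERVAL YEAR TO MONTH"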
@@ -9944,6 +10085,21 @@ def _get_spark_function_name(
                return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
            else:
                return f"{operation_func}(cast({date_param_name1} as date), cast({snowpark_arg_names[1]} as double))"
+        case (DateType(), DayTimeIntervalType()) | (
+            DateType(),
+            YearMonthIntervalType(),
+        ):
+            date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
+            return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
+        case (DayTimeIntervalType(), DateType()) | (
+            YearMonthIntervalType(),
+            DateType(),
+        ):
+            date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
+            if function_name == "+":
+                return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
+            else:
+                return default_spark_function_name
        case (DateType() as dt, _) | (_, DateType() as dt):
            date_param_index = 0 if dt == col1.typ else 1
            date_param_name = _get_literal_param_name(