snowpark-connect 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/column_name_handler.py +73 -100
- snowflake/snowpark_connect/column_qualifier.py +47 -0
- snowflake/snowpark_connect/dataframe_container.py +3 -2
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
- snowflake/snowpark_connect/expression/map_expression.py +5 -4
- snowflake/snowpark_connect/expression/map_extension.py +12 -6
- snowflake/snowpark_connect/expression/map_sql_expression.py +38 -3
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +5 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +869 -107
- snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
- snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +4 -3
- snowflake/snowpark_connect/relation/map_extension.py +10 -9
- snowflake/snowpark_connect/relation/map_join.py +5 -2
- snowflake/snowpark_connect/relation/map_sql.py +33 -1
- snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
- snowflake/snowpark_connect/relation/write/map_write.py +29 -14
- snowflake/snowpark_connect/server.py +1 -2
- snowflake/snowpark_connect/type_mapping.py +36 -3
- snowflake/snowpark_connect/typed_column.py +8 -6
- snowflake/snowpark_connect/utils/session.py +19 -3
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +1 -1
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +5 -2
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +36 -37
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -20,7 +20,7 @@ from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
@@ -66,6 +66,7 @@ from snowflake.snowpark.types import (
     TimestampType,
     VariantType,
     YearMonthIntervalType,
+    _AnsiIntervalType,
     _FractionalType,
     _IntegralType,
     _NumericType,
@@ -74,6 +75,7 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import (
     get_boolean_session_config_param,
     get_timestamp_type,
@@ -148,7 +150,11 @@ from snowflake.snowpark_connect.utils.xxhash64 import (
 MAX_UINT64 = 2**64 - 1
 MAX_INT64 = 2**63 - 1
 MIN_INT64 = -(2**63)
-
+MAX_32BIT_SIGNED_INT = 2_147_483_647
+
+# Interval arithmetic precision limits
+MAX_DAY_TIME_DAYS = 106751991  # Maximum days for day-time intervals
+MAX_10_DIGIT_LIMIT = 1000000000  # 10-digit limit (1 billion) for interval operands
 
 NAN, INFINITY = float("nan"), float("inf")
 
@@ -272,6 +278,40 @@ def _coerce_for_comparison(
     return left_col, right_col
 
 
+def _preprocess_not_equals_expression(exp: expressions_proto.Expression) -> str:
+    """
+    Transform NOT(col1 = col2) expressions to col1 != col2 for Snowflake compatibility.
+
+    Snowflake has issues with NOT (col1 = col2) in subqueries, so we rewrite
+    not(==(a, b)) to a != b by modifying the protobuf expression early.
+
+    Returns:
+        The (potentially modified) function name as a lowercase string.
+    """
+    function_name = exp.unresolved_function.function_name.lower()
+
+    # Snowflake has issues with NOT (col1 = col2) in subqueries.
+    # Transform not(==(a, b)) to a != b by modifying the protobuf early.
+    if (
+        function_name in ("not", "!")
+        and len(exp.unresolved_function.arguments) == 1
+        and exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+        == "unresolved_function"
+        and exp.unresolved_function.arguments[0].unresolved_function.function_name
+        == "=="
+    ):
+        inner_eq_func = exp.unresolved_function.arguments[0].unresolved_function
+        inner_args = list(inner_eq_func.arguments)
+
+        exp.unresolved_function.function_name = "!="
+        exp.unresolved_function.ClearField("arguments")
+        exp.unresolved_function.arguments.extend(inner_args)
+
+        function_name = "!="
+
+    return function_name
+
+
 def map_unresolved_function(
     exp: expressions_proto.Expression,
     column_mapping: ColumnNameMap,
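For orientation, here is a minimal sketch of the not(==(a, b)) -> a != b rewrite that _preprocess_not_equals_expression performs, run against a hand-built Spark Connect expression protobuf. The proto classes ship with pyspark; the attribute names "a" and "b" are purely illustrative.

import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto

# Encode not(==(a, b)) the way a Spark Connect client would.
exp = expressions_proto.Expression()
exp.unresolved_function.function_name = "not"
inner = exp.unresolved_function.arguments.add()
inner.unresolved_function.function_name = "=="
inner.unresolved_function.arguments.add().unresolved_attribute.unparsed_identifier = "a"
inner.unresolved_function.arguments.add().unresolved_attribute.unparsed_identifier = "b"

# The rewrite: hoist the equality's operands and flip the outer function,
# mirroring the sequence used in the diff above.
inner_args = list(inner.unresolved_function.arguments)
exp.unresolved_function.function_name = "!="
exp.unresolved_function.ClearField("arguments")
exp.unresolved_function.arguments.extend(inner_args)
# exp now encodes a != b directly, sidestepping NOT (a = b) in generated subqueries.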
@@ -300,6 +340,9 @@ def map_unresolved_function(
     # Inject default parameters for functions that need them (especially for Scala clients)
     inject_function_defaults(exp.unresolved_function)
 
+    # Transform NOT(col = col) to col != col for Snowflake compatibility
+    function_name = _preprocess_not_equals_expression(exp)
+
     def _resolve_args_expressions(exp: expressions_proto.Expression):
         def _resolve_fn_arg(exp):
             with resolving_fun_args():
@@ -355,7 +398,7 @@ def map_unresolved_function(
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
     result_type: Optional[DataType | List[DateType]] = None
-
+    qualifier_parts: List[str] = []
 
     pyspark_func = getattr(pyspark_functions, function_name, None)
     if pyspark_func and pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
@@ -513,9 +556,17 @@ def map_unresolved_function(
             )
             result_type = [f.datatype for f in udtf.output_schema]
         case "!=":
-
-
+            _check_interval_string_comparison(
+                "!=", snowpark_typed_args, snowpark_arg_names
+            )
+            # Make the function name same as spark connect. a != b translates to not(a=b)
+            spark_function_name = (
+                f"(NOT ({snowpark_arg_names[0]} = {snowpark_arg_names[1]}))"
+            )
+            left, right = _coerce_for_comparison(
+                snowpark_typed_args[0], snowpark_typed_args[1]
             )
+            result_exp = TypedColumn(left != right, lambda: [BooleanType()])
         case "%" | "mod":
             if spark_sql_ansi_enabled:
                 result_exp = snowpark_args[0] % snowpark_args[1]
@@ -616,12 +667,87 @@ def map_unresolved_function(
                     result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
                         result_type
                     )
-                case (
-                    (_NumericType() as t, NullType())
-                    | (NullType(), _NumericType() as t)
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[1] * snowpark_args[0].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_type = type(
+                            t
+                        )()  # YearMonthIntervalType() or DayTimeIntervalType()
+                        result_exp = snowpark_args[0] * snowpark_args[1].try_cast(
+                            LongType()
+                        )
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (
+                    (_NumericType() as t, NullType())
+                    | (NullType(), _NumericType() as t)
                 ):
                     result_type = t
                     result_exp = snowpark_fn.lit(None)
+                case (NullType(), t) | (t, NullType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    if isinstance(snowpark_typed_args[0].typ, NullType):
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] * snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] * snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} * {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] * snowpark_args[1]
+                case (_NumericType(), t) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[1] * snowpark_args[0]
+                    spark_function_name = (
+                        f"({snowpark_arg_names[1]} * {snowpark_arg_names[0]})"
+                    )
                 case (_NumericType(), _NumericType()):
                     result_type = _find_common_type(
                         [arg.typ for arg in snowpark_typed_args]
@@ -662,7 +788,14 @@ def map_unresolved_function(
                         result_type = DateType()
                         result_exp = snowpark_args[0] + snowpark_args[1]
                     elif isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
-                        result_type =
+                        result_type = (
+                            TimestampType()
+                            if isinstance(
+                                snowpark_typed_args[t_param_index].typ,
+                                DayTimeIntervalType,
+                            )
+                            else DateType()
+                        )
                         result_exp = (
                             snowpark_args[date_param_index]
                             + snowpark_args[t_param_index]
@@ -685,6 +818,35 @@ def map_unresolved_function(
                     )
                     attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
+                case (TimestampType(), t) | (t, TimestampType()):
+                    timestamp_param_index = (
+                        0
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else 1
+                    )
+                    t_param_index = 1 - timestamp_param_index
+                    if isinstance(t, (DayTimeIntervalType, YearMonthIntervalType)):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    elif (
+                        hasattr(
+                            snowpark_typed_args[t_param_index].col._expr1, "pretty_name"
+                        )
+                        and "INTERVAL"
+                        == snowpark_typed_args[t_param_index].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = (
+                            snowpark_args[timestamp_param_index]
+                            + snowpark_args[t_param_index]
+                        )
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[t_param_index]}" has the type "{t}".',
+                        )
                 case (StringType(), StringType()):
                     if spark_sql_ansi_enabled:
                         exception = AnalysisException(
@@ -736,6 +898,86 @@ def map_unresolved_function(
                     )
                     attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
+                case (t1, t2) | (t2, t1) if isinstance(
+                    t1, _AnsiIntervalType
+                ) and isinstance(t2, _AnsiIntervalType) and type(t1) == type(t2):
+                    # Both operands are the same interval type
+                    result_type = type(t1)(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    # String + YearMonthInterval: Spark tries to cast string to double first, throws error if it fails
+                    result_type = StringType()
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        result_exp = (
+                            snowpark_fn.cast(snowpark_args[0], "double")
+                            + snowpark_args[1]
+                        )
+                    else:
+                        result_exp = snowpark_args[0] + snowpark_fn.cast(
+                            snowpark_args[1], "double"
+                        )
+                case (StringType(), t) | (t, StringType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    # String + DayTimeInterval: try to parse string as timestamp, return NULL if it fails
+                    # For time-only strings (like '10:00:00'), prepend current date to make it a full timestamp
+                    result_type = StringType()
+                    if isinstance(snowpark_typed_args[0].typ, StringType):
+                        # Check if string looks like time-only (HH:MM:SS or HH:MM pattern)
+                        # If so, prepend current date; otherwise use as-is
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[0], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "YYYY-MM-DD",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[0],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[0]
+                            )
+                        )
+                        result_exp = timestamp_expr + snowpark_args[1]
+                    else:
+                        # interval + string case
+                        time_only_pattern = snowpark_fn.function("regexp_like")(
+                            snowpark_args[1], r"^\d{1,2}:\d{2}(:\d{2})?$"
+                        )
+                        timestamp_expr = snowpark_fn.when(
+                            time_only_pattern,
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_fn.function("concat")(
+                                    snowpark_fn.function("to_char")(
+                                        snowpark_fn.function("current_date")(),
+                                        "'YYYY-MM-DD'",
+                                    ),
+                                    snowpark_fn.lit(" "),
+                                    snowpark_args[1],
+                                )
+                            ),
+                        ).otherwise(
+                            snowpark_fn.function("try_to_timestamp_ntz")(
+                                snowpark_args[1]
+                            )
+                        )
+                        result_exp = snowpark_args[0] + timestamp_expr
+                        spark_function_name = (
+                            f"{snowpark_arg_names[0]} + {snowpark_arg_names[1]}"
+                        )
+
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -781,7 +1023,11 @@ def map_unresolved_function(
                     DateType(),
                     YearMonthIntervalType(),
                 ):
-                    result_type =
+                    result_type = (
+                        TimestampType()
+                        if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType)
+                        else DateType()
+                    )
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
                     if (
@@ -799,6 +1045,23 @@ def map_unresolved_function(
                         result_exp = snowpark_args[0] - snowpark_args[1].cast(
                             input_type
                         )
+                case (TimestampType(), DayTimeIntervalType()) | (
+                    TimestampType(),
+                    YearMonthIntervalType(),
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), StringType()):
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
+                        result_type = TimestampType()
+                        result_exp = snowpark_args[0] - snowpark_args[1]
+                    else:
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 2 requires the ("INTERVAL") type for timestamp operations, however "{snowpark_arg_names[1]}" has the type "{snowpark_typed_args[1].typ}".',
+                        )
                 case (StringType(), DateType()):
                     # TODO SNOW-2034420: resolve return type (it should be INTERVAL DAY)
                     result_type = LongType()
@@ -870,6 +1133,16 @@ def map_unresolved_function(
                     )
                     attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
+                case (StringType(), t) if isinstance(t, _AnsiIntervalType):
+                    # String - Interval: try to parse string as timestamp, return NULL if it fails
+                    result_type = StringType()
+                    result_exp = (
+                        snowpark_fn.function("try_to_timestamp")(snowpark_args[0])
+                        - snowpark_args[1]
+                    )
+                    spark_function_name = (
+                        f"{snowpark_arg_names[0]} - {snowpark_arg_names[1]}"
+                    )
                 case _:
                     result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -968,9 +1241,57 @@ def map_unresolved_function(
                     result_exp = _divnull(
                         snowpark_args[0], snowpark_args[1].try_cast(result_type)
                     )
+                case (t, StringType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1].try_cast(
+                        LongType()
+                    )
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
                 case (_NumericType(), NullType()) | (NullType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = snowpark_fn.lit(None)
+                case (t, NullType()) if isinstance(t, _AnsiIntervalType):
+                    # Only allow interval / null, not null / interval
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_fn.lit(None)
+                    spark_function_name = (
+                        f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                    )
+                case (DecimalType(), t) | (t, DecimalType()) if isinstance(
+                    t, _AnsiIntervalType
+                ):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    if isinstance(snowpark_typed_args[0].typ, DecimalType):
+                        result_exp = snowpark_args[1] / snowpark_args[0]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[1]} / {snowpark_arg_names[0]})"
+                        )
+                    else:
+                        result_exp = snowpark_args[0] / snowpark_args[1]
+                        spark_function_name = (
+                            f"({snowpark_arg_names[0]} / {snowpark_arg_names[1]})"
+                        )
+                case (t, _NumericType()) if isinstance(t, _AnsiIntervalType):
+                    result_type = (
+                        YearMonthIntervalType()
+                        if isinstance(t, YearMonthIntervalType)
+                        else DayTimeIntervalType()
+                    )
+                    result_exp = snowpark_args[0] / snowpark_args[1]
                 case (_NumericType(), _NumericType()):
                     result_type = DoubleType()
                     result_exp = _divnull(
@@ -2027,11 +2348,6 @@ def map_unresolved_function(
             result_exp = snowpark_fn.coalesce(
                 *[col.cast(result_type) for col in snowpark_args]
             )
-        case "col":
-            # TODO: assign type
-            result_exp = snowpark_fn.col(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "collect_list" | "array_agg":
             # TODO: SNOW-1967177 - Support structured types in array_agg
             result_exp = snowpark_fn.array_agg(
@@ -2049,11 +2365,6 @@ def map_unresolved_function(
             result_exp = _resolve_aggregate_exp(
                 result_exp, ArrayType(snowpark_typed_args[0].typ)
             )
-        case "column":
-            # TODO: assign type
-            result_exp = snowpark_fn.column(*snowpark_args)
-            result_exp = _type_with_typer(result_exp)
-            qualifiers = snowpark_args[0].get_qualifiers()
         case "concat":
             if len(snowpark_args) == 0:
                 result_exp = TypedColumn(snowpark_fn.lit(""), lambda: [StringType()])
@@ -2232,7 +2543,7 @@ def map_unresolved_function(
                     snowpark_fn.col("*", _is_qualified_name=True)
                 )
             else:
-                result_exp = snowpark_fn.
+                result_exp = snowpark_fn.call_function("COUNT", *snowpark_args)
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
         case "count_if":
             result_exp = snowpark_fn.call_function("COUNT_IF", snowpark_args[0])
@@ -2670,16 +2981,6 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "date_part" | "datepart" | "extract":
-            # Check for interval types and throw NotImplementedError
-            if isinstance(
-                snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            ):
-                exception = NotImplementedError(
-                    f"{function_name} with interval types is not supported"
-                )
-                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-                raise exception
-
             field_lit: str | None = unwrap_literal(exp.unresolved_function.arguments[0])
 
             if field_lit is None:
@@ -2724,16 +3025,51 @@ def map_unresolved_function(
         case "div":
             # Only called from SQL, either as `a div b` or `div(a, b)`
            # Convert it into `(a - a % b) / b`.
-
-            (
-
-
-
-
-
-
-
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, YearMonthIntervalType):
+                    dividend_total = _calculate_total_months(snowpark_args[0])
+                    divisor_total = _calculate_total_months(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                if isinstance(snowpark_typed_args[1].typ, DayTimeIntervalType):
+                    dividend_total = _calculate_total_seconds(snowpark_args[0])
+                    divisor_total = _calculate_total_seconds(snowpark_args[1])
+
+                    # Handle division by zero interval
+                    if not spark_sql_ansi_enabled:
+                        result_exp = snowpark_fn.when(
+                            divisor_total == 0, snowpark_fn.lit(None)
+                        ).otherwise(snowpark_fn.trunc(dividend_total / divisor_total))
+                    else:
+                        result_exp = snowpark_fn.trunc(dividend_total / divisor_total)
+                    result_type = LongType()
+                else:
+                    raise AnalysisException(
+                        f"""[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "({snowpark_arg_names[0]} div {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ({snowpark_typed_args[0].typ} and {snowpark_typed_args[1].typ}).;"""
+                    )
+            else:
+                result_exp = snowpark_fn.cast(
+                    (snowpark_args[0] - snowpark_args[0] % snowpark_args[1])
+                    / snowpark_args[1],
+                    LongType(),
+                )
+                if not spark_sql_ansi_enabled:
+                    result_exp = snowpark_fn.when(
+                        snowpark_args[1] == 0, snowpark_fn.lit(None)
+                    ).otherwise(result_exp)
+                result_type = LongType()
         case "e":
             spark_function_name = "E()"
             result_exp = snowpark_fn.lit(math.e)
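As a quick sanity check on the `(a - a % b) / b` rewrite used for `div`: it yields integer division that truncates toward zero, matching Spark's `a div b`. The plain-Python model below is only illustrative; it uses math.fmod because SQL engines compute the remainder with the dividend's sign, unlike Python's % operator.

import math

def sql_div(a: int, b: int) -> int:
    r = math.fmod(a, b)      # remainder carries the dividend's sign, as in SQL
    return int((a - r) / b)  # exact division once the remainder is removed

assert sql_div(7, 2) == 3
assert sql_div(-7, 2) == -3  # truncates toward zero; Python's -7 // 2 would give -4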
@@ -3565,8 +3901,6 @@ def map_unresolved_function(
         result should be either way good enough.
         """
 
-        from datetime import date, datetime, time, timedelta
-
         def __init__(self) -> None:
 
             # init the RNG for breaking ties in histogram merging. A fixed seed is specified here
@@ -3710,7 +4044,8 @@ def map_unresolved_function(
             # just increment 'bin'. This is not done now because we don't want to make any
             # assumptions about the range of numeric data being analyzed.
             if bin < self.n_used_bins and self.bins[bin][0] == v:
-                self.bins[bin]
+                bin_x, bin_y = self.bins[bin]
+                self.bins[bin] = (bin_x, bin_y + 1)
             else:
                 self.bins.insert(bin + 1, (v, 1.0))
                 self.n_used_bins += 1
@@ -4504,6 +4839,17 @@ def map_unresolved_function(
             date_str_exp = snowpark_fn.concat(y, dash, m, dash, d)
             result_exp = snowpark_fn.builtin(snowpark_function)(date_str_exp)
             result_type = DateType()
+        case "make_dt_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 3:  # days, hours, minutes are integers
+                padded_arg_names.append("0")
+            if len(padded_arg_names) < 4:  # seconds can be decimal
+                padded_arg_names.append("0.000000")
+
+            spark_function_name = f"make_dt_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_day_time_from_parts(*snowpark_args)
+            result_type = DayTimeIntervalType()
         case "make_timestamp" | "make_timestamp_ltz" | "make_timestamp_ntz":
             y, m, d, h, mins = map(lambda col: col.cast(LongType()), snowpark_args[:5])
             y_abs = snowpark_fn.abs(y)
@@ -4557,6 +4903,15 @@ def map_unresolved_function(
             result_exp = snowpark_fn.when(
                 snowpark_fn.is_null(parsed_str_exp), snowpark_fn.lit(None)
             ).otherwise(make_timestamp_res)
+        case "make_ym_interval":
+            # Pad argument names for display purposes
+            padded_arg_names = snowpark_arg_names.copy()
+            while len(padded_arg_names) < 2:  # years, months
+                padded_arg_names.append("0")
+
+            spark_function_name = f"make_ym_interval({', '.join(padded_arg_names)})"
+            result_exp = snowpark_fn.interval_year_month_from_parts(*snowpark_args)
+            result_type = YearMonthIntervalType()
         case "map":
             allow_duplicate_keys = (
                 global_config.spark_sql_mapKeyDedupPolicy == "LAST_WIN"
@@ -5211,7 +5566,11 @@ def map_unresolved_function(
                 spark_function_name = f"(- {snowpark_arg_names[0]})"
             else:
                 spark_function_name = f"negative({snowpark_arg_names[0]})"
-            if
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 # Instead of using snowpark_fn.negate which can generate invalid SQL for nested minus operations,
                 # use a direct multiplication by -1 which generates cleaner SQL
                 result_exp = snowpark_args[0] * snowpark_fn.lit(-1)
@@ -5236,6 +5595,8 @@ def map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
         case "next_day":
@@ -5616,9 +5977,33 @@ def map_unresolved_function(
         case "percentile_cont" | "percentiledisc":
             if function_name == "percentiledisc":
                 function_name = "percentile_disc"
+            order_by_col = snowpark_args[0]
+            args = exp.unresolved_function.arguments
+            if len(args) != 3:
+                exception = AssertionError(
+                    f"{function_name} expected 3 args but got {len(args)}"
+                )
+                attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                raise exception
+            # literal value 0.0 - 1.0
+            percentage_arg = args[1]
+            sort_direction = args[2].sort_order.direction
+            direction_str = ""  # defaultValue
+            if (
+                sort_direction
+                == expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING
+            ):
+                direction_str = "DESC"
+
+            # Apply sort direction to the order_by column
+            if direction_str == "DESC":
+                order_by_col_with_direction = order_by_col.desc()
+            else:
+                order_by_col_with_direction = order_by_col.asc()
+
             result_exp = snowpark_fn.function(function_name)(
-                _check_percentile_percentage(
-            ).within_group(
+                _check_percentile_percentage(percentage_arg)
+            ).within_group(order_by_col_with_direction)
             result_exp = (
                 TypedColumn(
                     snowpark_fn.cast(result_exp, FloatType()), lambda: [DoubleType()]
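A hedged sketch of the Snowpark expression this branch builds for a descending percentile_disc; Column.within_group and functions.function are standard Snowpark APIs, while the column name "price" is invented for the example.

from snowflake.snowpark import functions as snowpark_fn

expr = snowpark_fn.function("percentile_disc")(snowpark_fn.lit(0.5)).within_group(
    snowpark_fn.col("price").desc()
)
# Generates roughly: PERCENTILE_DISC(0.5) WITHIN GROUP (ORDER BY "PRICE" DESC)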
@@ -5627,7 +6012,8 @@ def map_unresolved_function(
                 else TypedColumnWithDeferredCast(result_exp, lambda: [DoubleType()])
             )
 
-
+            direction_part = f" {direction_str}" if direction_str else ""
+            spark_function_name = f"{function_name}({unwrap_literal(percentage_arg)}) WITHIN GROUP (ORDER BY {snowpark_arg_names[0]}{direction_part})"
         case "pi":
             spark_function_name = "PI()"
             result_exp = snowpark_fn.lit(math.pi)
@@ -5767,7 +6153,11 @@ def map_unresolved_function(
         case "positive":
             arg_type = snowpark_typed_args[0].typ
             spark_function_name = f"(+ {snowpark_arg_names[0]})"
-            if
+            if (
+                isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
+            ):
                 result_exp = snowpark_args[0]
             elif isinstance(arg_type, StringType):
                 if spark_sql_ansi_enabled:
@@ -5790,6 +6180,8 @@ def map_unresolved_function(
             result_type = (
                 snowpark_typed_args[0].types
                 if isinstance(arg_type, _NumericType)
+                or isinstance(arg_type, YearMonthIntervalType)
+                or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
 
@@ -6660,11 +7052,43 @@ def map_unresolved_function(
                 fn_name = "sign"
 
             spark_function_name = f"{fn_name}({snowpark_arg_names[0]})"
-
-
-
-
-
+
+            if isinstance(snowpark_typed_args[0].typ, YearMonthIntervalType):
+                # Use SQL expression for zero year-month interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0-0' YEAR TO MONTH"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            elif isinstance(snowpark_typed_args[0].typ, DayTimeIntervalType):
+                # Use SQL expression for zero day-time interval comparison
+                result_exp = (
+                    snowpark_fn.when(
+                        snowpark_args[0]
+                        > snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(1.0),
+                    )
+                    .when(
+                        snowpark_args[0]
+                        < snowpark_fn.sql_expr("INTERVAL '0 0:0:0' DAY TO SECOND"),
+                        snowpark_fn.lit(-1.0),
+                    )
+                    .otherwise(snowpark_fn.lit(0.0))
+                )
+            else:
+                result_exp = snowpark_fn.when(
+                    snowpark_args[0] == NAN, snowpark_fn.lit(NAN)
+                ).otherwise(
+                    snowpark_fn.cast(snowpark_fn.sign(snowpark_args[0]), DoubleType())
+                )
             result_type = DoubleType()
         case "sin":
             spark_function_name = f"SIN({snowpark_arg_names[0]})"
@@ -6909,7 +7333,16 @@ def map_unresolved_function(
             )
             raise exception
         case "split_part":
-
+            # Check for index 0 and throw error to match PySpark behavior
+            raise_error = _raise_error_helper(StringType(), SparkRuntimeException)
+            result_exp = snowpark_fn.when(
+                snowpark_args[2] == 0,
+                raise_error(
+                    snowpark_fn.lit(
+                        "[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
+                    )
+                ),
+            ).otherwise(snowpark_fn.call_function("split_part", *snowpark_args))
             result_type = StringType()
         case "sqrt":
             spark_function_name = f"SQRT({snowpark_arg_names[0]})"
@@ -7939,18 +8372,123 @@ def map_unresolved_function(
             )
             result_type = DateType()
         case "try_add":
-            #
-
-
-
-
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, YearMonthIntervalType
+                ):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (DateType(), t) | (t, DateType()) if isinstance(
+                    t, DayTimeIntervalType
+                ):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (TimestampType(), t) | (t, TimestampType()) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = (
+                        snowpark_typed_args[0].typ
+                        if isinstance(snowpark_typed_args[0].typ, TimestampType)
+                        else snowpark_typed_args[1].typ
                     )
-
-
+                    result_exp = snowpark_args[0] + snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval + numeric/string or numeric/string + YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_add({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
-
-
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+
+                    # For year-month intervals, throw ArithmeticException if operands reach 10+ digits OR result exceeds 9 digits
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 > 0)
+                            & (total1 >= ten_digit_limit - total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 < 0)
+                            & (total1 <= -ten_digit_limit - total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] + snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        # Check if result would exceed day limit (positive overflow)
+                        ((days1 > 0) & (days2 > 0) & (days1 > max_days - days2))
+                        | ((days1 < 0) & (days2 < 0) & (days1 < min_days - days2))
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds Snowflake interval precision limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] + snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 0
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_aes_decrypt":
             result_exp = _aes_helper(
                 "TRY_DECRYPT",
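The overflow guard above deliberately avoids computing total1 + total2 directly; a plain-integer model of the same predicate (illustrative only, with ints standing in for Snowpark column expressions):

LIMIT = 1_000_000_000  # MAX_10_DIGIT_LIMIT

def would_overflow(total1: int, total2: int) -> bool:
    return (
        abs(total1) >= LIMIT
        or abs(total2) >= LIMIT
        or (total1 > 0 and total2 > 0 and total1 >= LIMIT - total2)
        or (total1 < 0 and total2 < 0 and total1 <= -LIMIT - total2)
    )

assert would_overflow(999_999_999, 1)                # sum reaches the 10-digit bound
assert not would_overflow(500_000_000, 499_999_998)  # sum stays just under it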
@@ -8002,17 +8540,49 @@ def map_unresolved_function(
                 DoubleType(), cleaned, calculating_avg=True
             )
         case "try_divide":
-            #
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_divide with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
+            # Handle interval division with overflow detection
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval / numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    divisor = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    # Check for division by zero first
+                    zero_check = divisor == 0
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months / divisor) > max_months
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days / divisor) > max_days
+                        )
+                        result_exp = (
+                            snowpark_fn.when(zero_check, snowpark_fn.lit(None))
+                            .when(overflow_check, snowpark_fn.lit(None))
+                            .otherwise(interval_arg / divisor)
+                        )
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     result_type = FloatType()
@@ -8124,17 +8694,76 @@ def map_unresolved_function(
             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
             raise exception
         case "try_multiply":
-            # Check for interval types and throw NotImplementedError
-            for arg in snowpark_typed_args:
-                if isinstance(arg.typ, (YearMonthIntervalType, DayTimeIntervalType)):
-                    exception = NotImplementedError(
-                        "try_multiply with interval types is not supported"
-                    )
-                    attach_custom_error_code(
-                        exception, ErrorCodes.UNSUPPORTED_OPERATION
-                    )
-                    raise exception
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (t1, t2) if isinstance(t1, _AnsiIntervalType) and isinstance(
+                    t2, (_NumericType, StringType)
+                ):
+                    # Interval * numeric/string
+                    result_type = t1
+                    interval_arg = snowpark_args[0]
+                    multiplier = (
+                        snowpark_args[1]
+                        if isinstance(t2, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[1], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+
+                case (t1, t2) if isinstance(t2, _AnsiIntervalType) and isinstance(
+                    t1, (_NumericType, StringType)
+                ):
+                    # numeric/string * Interval
+                    result_type = t2
+                    interval_arg = snowpark_args[1]
+                    multiplier = (
+                        snowpark_args[0]
+                        if isinstance(t1, _NumericType)
+                        else snowpark_fn.cast(snowpark_args[0], "double")
+                    )
+
+                    if isinstance(result_type, YearMonthIntervalType):
+                        # For year-month intervals, check if result exceeds 32-bit signed integer limit
+                        result_type = YearMonthIntervalType()
+                        total_months = _calculate_total_months(interval_arg)
+                        max_months = snowpark_fn.lit(MAX_32BIT_SIGNED_INT)
+                        overflow_check = (
+                            snowpark_fn.abs(total_months * multiplier) > max_months
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
+                    else:  # DayTimeIntervalType
+                        # For day-time intervals, check if result exceeds day limit
+                        result_type = DayTimeIntervalType()
+                        total_days = _calculate_total_days(interval_arg)
+                        max_days = snowpark_fn.lit(MAX_DAY_TIME_DAYS)
+                        overflow_check = (
+                            snowpark_fn.abs(total_days * multiplier) > max_days
+                        )
+                        result_exp = snowpark_fn.when(
+                            overflow_check, snowpark_fn.lit(None)
+                        ).otherwise(interval_arg * multiplier)
                 case (NullType(), t) | (t, NullType()):
                     result_exp = snowpark_fn.lit(None)
                     match t:
@@ -8234,18 +8863,112 @@ def map_unresolved_function(
                 snowpark_typed_args[0].typ, snowpark_args[0]
             )
         case "try_subtract":
-            #
-
-            if isinstance(
-
-
-
-
+            # Handle interval arithmetic with overflow detection
+            match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), t) if isinstance(t, YearMonthIntervalType):
+                    result_type = DateType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (DateType(), t) if isinstance(t, DayTimeIntervalType):
+                    result_type = TimestampType()
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (TimestampType(), t) if isinstance(
+                    t, (DayTimeIntervalType, YearMonthIntervalType)
+                ):
+                    result_type = snowpark_typed_args[0].typ
+                    result_exp = snowpark_args[0] - snowpark_args[1]
+                case (t1, t2) if (
+                    isinstance(t1, YearMonthIntervalType)
+                    and isinstance(t2, (_NumericType, StringType))
+                ) or (
+                    isinstance(t2, YearMonthIntervalType)
+                    and isinstance(t1, (_NumericType, StringType))
+                ):
+                    # YearMonthInterval - numeric/string or numeric/string - YearMonthInterval should throw error
+                    exception = AnalysisException(
+                        f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "try_subtract({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
+                    attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
                     raise exception
-
-
+                case (t1, t2) if isinstance(t1, YearMonthIntervalType) and isinstance(
+                    t2, YearMonthIntervalType
+                ):
+                    result_type = YearMonthIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's precision limits: 10+ digits for operands, 9+ digits for results
+                    total1 = _calculate_total_months(snowpark_args[0])
+                    total2 = _calculate_total_months(snowpark_args[1])
+                    ten_digit_limit = snowpark_fn.lit(MAX_10_DIGIT_LIMIT)
+
+                    precision_violation = (
+                        # Check if either operand already reaches 10 digits (parsing limit)
+                        (snowpark_fn.abs(total1) >= ten_digit_limit)
+                        | (snowpark_fn.abs(total2) >= ten_digit_limit)
+                        | (
+                            (total1 > 0)
+                            & (total2 < 0)
+                            & (total1 >= ten_digit_limit + total2)
+                        )
+                        | (
+                            (total1 < 0)
+                            & (total2 > 0)
+                            & (total1 <= -ten_digit_limit + total2)
+                        )
+                    )
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = snowpark_fn.when(
+                        precision_violation,
+                        raise_error(
+                            snowpark_fn.lit(
+                                "Year-Month Interval result exceeds Snowflake interval precision limit"
+                            )
+                        ),
+                    ).otherwise(snowpark_args[0] - snowpark_args[1])
+                case (t1, t2) if isinstance(t1, DayTimeIntervalType) and isinstance(
+                    t2, DayTimeIntervalType
+                ):
+                    result_type = DayTimeIntervalType(
+                        min(t1.start_field, t2.start_field),
+                        max(t1.end_field, t2.end_field),
+                    )
+                    # Check for Snowflake's day limit (106751991 days is the cutoff)
+                    days1 = snowpark_fn.date_part("day", snowpark_args[0])
+                    days2 = snowpark_fn.date_part("day", snowpark_args[1])
+                    max_days = snowpark_fn.lit(
+                        MAX_DAY_TIME_DAYS
+                    )  # Snowflake's actual limit
+                    min_days = snowpark_fn.lit(-MAX_DAY_TIME_DAYS)
+
+                    # Check if either operand exceeds the day limit - throw error like Spark does
+                    operand_limit_violation = (snowpark_fn.abs(days1) > max_days) | (
+                        snowpark_fn.abs(days2) > max_days
+                    )
+
+                    # Check if result would exceed day limit (but operands are valid) - return NULL
+                    result_overflow = (
+                        (days1 > 0) & (days2 < 0) & (days1 > max_days + days2)
+                    ) | ((days1 < 0) & (days2 > 0) & (days1 < min_days + days2))
+
+                    raise_error = _raise_error_helper(result_type, ArithmeticException)
+                    result_exp = (
+                        snowpark_fn.when(
+                            operand_limit_violation,
+                            raise_error(
+                                snowpark_fn.lit(
+                                    "Day-Time Interval operand exceeds day limit"
+                                )
+                            ),
+                        )
+                        .when(result_overflow, snowpark_fn.lit(None))
+                        .otherwise(snowpark_args[0] - snowpark_args[1])
+                    )
+                case _:
+                    result_exp = _try_arithmetic_helper(
+                        snowpark_typed_args, snowpark_args, 1
+                    )
+                    result_exp = _type_with_typer(result_exp)
         case "try_to_number":
             try_to_number = snowpark_fn.function("try_to_number")
             precision, scale = resolve_to_number_precision_and_scale(exp)
@@ -8828,7 +9551,7 @@ def map_unresolved_function(
         spark_col_names if len(spark_col_names) > 0 else [spark_function_name]
     )
     typed_col = _to_typed_column(result_exp, result_type, function_name)
-    typed_col.set_qualifiers(
+    typed_col.set_qualifiers({ColumnQualifier(tuple(qualifier_parts))})
     return spark_col_names, typed_col
 
 
@@ -9025,6 +9748,20 @@ def _find_common_type(
             key_type = _common(type1.key_type, type2.key_type)
             value_type = _common(type1.value_type, type2.value_type)
             return MapType(key_type, value_type)
+        case (_, _) if isinstance(type1, YearMonthIntervalType) and isinstance(
+            type2, YearMonthIntervalType
+        ):
+            return YearMonthIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
+        case (_, _) if isinstance(type1, DayTimeIntervalType) and isinstance(
+            type2, DayTimeIntervalType
+        ):
+            return DayTimeIntervalType(
+                min(type1.start_field, type2.start_field),
+                max(type1.end_field, type2.end_field),
+            )
        case _:
             exception = AnalysisException(exception_base_message)
             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
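The min/max field widening mirrors how PySpark itself merges interval types; an illustrative check against pyspark.sql.types (note PySpark spells the fields startField/endField, while the Snowpark types above use start_field/end_field):

from pyspark.sql.types import YearMonthIntervalType

t1 = YearMonthIntervalType(0, 0)  # INTERVAL YEAR
t2 = YearMonthIntervalType(1, 1)  # INTERVAL MONTH
common = YearMonthIntervalType(
    min(t1.startField, t2.startField),
    max(t1.endField, t2.endField),
)
print(common.simpleString())  # interval year to month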
@@ -10384,9 +11121,7 @@ def _get_add_sub_result_type(
     return result_type, overflow_possible
 
 
-def _get_interval_type_name(
-    interval_type: Union[YearMonthIntervalType, DayTimeIntervalType]
-) -> str:
+def _get_interval_type_name(interval_type: _AnsiIntervalType) -> str:
     """Get the formatted interval type name for error messages."""
     if isinstance(interval_type, YearMonthIntervalType):
         if interval_type.start_field == 0 and interval_type.end_field == 0:
@@ -10413,21 +11148,15 @@ def _check_interval_string_comparison(
 ) -> None:
     """Check for invalid interval-string comparisons and raise AnalysisException if found."""
     if (
-        isinstance(
-            snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
         and isinstance(snowpark_typed_args[1].typ, StringType)
         or isinstance(snowpark_typed_args[0].typ, StringType)
-        and isinstance(
-            snowpark_typed_args[1].typ, (YearMonthIntervalType, DayTimeIntervalType)
-        )
+        and isinstance(snowpark_typed_args[1].typ, _AnsiIntervalType)
     ):
         # Format interval type name for error message
         interval_type = (
             snowpark_typed_args[0].typ
-            if isinstance(
-                snowpark_typed_args[0].typ, (YearMonthIntervalType, DayTimeIntervalType)
-            )
+            if isinstance(snowpark_typed_args[0].typ, _AnsiIntervalType)
             else snowpark_typed_args[1].typ
         )
         interval_name = _get_interval_type_name(interval_type)
@@ -10494,12 +11223,18 @@ def _get_spark_function_name(
         case (DateType(), DayTimeIntervalType()) | (
             DateType(),
             YearMonthIntervalType(),
+        ) | (TimestampType(), DayTimeIntervalType()) | (
+            TimestampType(),
+            YearMonthIntervalType(),
         ):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
             return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
         case (DayTimeIntervalType(), DateType()) | (
             YearMonthIntervalType(),
             DateType(),
+        ) | (DayTimeIntervalType(), TimestampType()) | (
+            YearMonthIntervalType(),
+            TimestampType(),
         ):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
             if function_name == "+":
@@ -10887,3 +11622,30 @@ def _map_from_spark_tz(value: Column) -> Column:
         .when(value == "VST", snowpark_fn.lit("Asia/Ho_Chi_Minh"))
         .otherwise(value)  # Return original timezone if no mapping found
     )
+
+
+def _calculate_total_months(interval_arg):
+    """Calculate total months from a year-month interval."""
+    years = snowpark_fn.date_part("year", interval_arg)
+    months = snowpark_fn.date_part("month", interval_arg)
+    return years * 12 + months
+
+
+def _calculate_total_days(interval_arg):
+    """Calculate total days from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    # Convert hours, minutes, seconds to fractional days
+    fractional_days = (hours * 3600 + minutes * 60 + seconds) / 86400
+    return days + fractional_days
+
+
+def _calculate_total_seconds(interval_arg):
+    """Calculate total seconds from a day-time interval."""
+    days = snowpark_fn.date_part("day", interval_arg)
+    hours = snowpark_fn.date_part("hour", interval_arg)
+    minutes = snowpark_fn.date_part("minute", interval_arg)
+    seconds = snowpark_fn.date_part("second", interval_arg)
+    return days * 86400 + hours * 3600 + minutes * 60 + seconds
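A plain-Python analogue of the three helpers, to sanity-check the unit arithmetic; the real versions operate on Snowpark column expressions via date_part.

def total_months(years: int, months: int) -> int:
    return years * 12 + months

def total_seconds(days: int, hours: int, minutes: int, seconds: float) -> float:
    return days * 86400 + hours * 3600 + minutes * 60 + seconds

assert total_months(2, 3) == 27                # INTERVAL '2-3' YEAR TO MONTH
assert total_seconds(1, 2, 0, 0.5) == 93600.5  # INTERVAL '1 02:00:00.5' DAY TO SECOND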