snowpark-connect 0.28.1__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +11 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
- snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +38 -4
- snowflake/snowpark_connect/server.py +18 -13
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +3 -0
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +28 -28
- {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
@@ -264,16 +264,22 @@ SESSION_CONFIG_KEY_WHITELIST = {
     "snowpark.connect.udtf.compatibility_mode",
     "snowpark.connect.views.duplicate_column_names_handling_mode",
     "enable_snowflake_extension_behavior",
+    "spark.hadoop.fs.s3a.server-side-encryption.key",
+    "spark.hadoop.fs.s3a.assumed.role.arn",
 }
-AZURE_SAS_KEY = re.compile(
+AZURE_ACCOUNT_KEY = re.compile(
     r"^fs\.azure\.sas\.[^\.]+\.[^\.]+\.blob\.core\.windows\.net$"
 )
+AZURE_SAS_KEY = re.compile(
+    r"^fs\.azure\.sas\.fixed\.token\.[^\.]+\.dfs\.core\.windows\.net$"
+)


 def valid_session_config_key(key: str):
     return (
         key in SESSION_CONFIG_KEY_WHITELIST  # AWS session keys
         or AZURE_SAS_KEY.match(key)  # Azure session keys
+        or AZURE_ACCOUNT_KEY.match(key)  # Azure account keys
     )

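This hunk appears to be from snowflake/snowpark_connect/config.py (per the file list above): two S3A keys join the whitelist, and the Azure patterns are split into an account-level blob key and a fixed-token SAS key for ADLS. A minimal standalone sketch of how the two regexes classify keys; the patterns are copied from the hunk, while the container and account names are made up:

    import re

    # Patterns as added in the hunk above.
    AZURE_ACCOUNT_KEY = re.compile(
        r"^fs\.azure\.sas\.[^\.]+\.[^\.]+\.blob\.core\.windows\.net$"
    )
    AZURE_SAS_KEY = re.compile(
        r"^fs\.azure\.sas\.fixed\.token\.[^\.]+\.dfs\.core\.windows\.net$"
    )

    # Hypothetical container/account names, for illustration only.
    assert AZURE_ACCOUNT_KEY.match("fs.azure.sas.mycontainer.myaccount.blob.core.windows.net")
    assert AZURE_SAS_KEY.match("fs.azure.sas.fixed.token.myaccount.dfs.core.windows.net")
    assert not AZURE_SAS_KEY.match("fs.azure.sas.mycontainer.myaccount.blob.core.windows.net")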
@@ -578,7 +584,10 @@ def set_snowflake_parameters(
             cte_enabled = str_to_bool(value)
             snowpark_session.cte_optimization_enabled = cte_enabled
             logger.info(f"Updated snowpark session CTE optimization: {cte_enabled}")
-
+        case "snowpark.connect.structured_types.fix":
+            # TODO: SNOW-2367714 Remove this once the fix is automatically enabled in Snowpark
+            snowpark.context._enable_fix_2360274 = str_to_bool(value)
+            logger.info(f"Updated snowpark session structured types fix: {value}")
         case _:
             pass

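The new case wires a runtime toggle, "snowpark.connect.structured_types.fix", through the same session-config path as the CTE switch above. A brief usage sketch, assuming a PySpark session (here called `spark`) connected through snowpark-connect; the key name comes from the hunk, everything else is illustrative:

    # Flip the Snowpark structured-types fix on for the current session (illustrative).
    spark.conf.set("snowpark.connect.structured_types.fix", "true")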
@@ -15,6 +15,7 @@ import tempfile
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
@@ -199,7 +200,7 @@ def _validate_numeric_args(
             case StringType():
                 # Cast strings to doubles following Spark
                 # https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L204
-                modified_args[i] =
+                modified_args[i] = snowpark_fn.try_cast(snowpark_args[i], DoubleType())
             case _:
                 raise TypeError(
                     f"Data type mismatch: {function_name} requires numeric types, but got {typed_args[0].typ} and {typed_args[1].typ}."
@@ -519,7 +520,7 @@ def map_unresolved_function(
                     DecimalType() as t,
                 ):
                     p1, s1 = _get_type_precision(t)
-                    result_type = _get_decimal_multiplication_result_type(
+                    result_type, _ = _get_decimal_multiplication_result_type(
                         p1, s1, p1, s1
                     )
                     result_exp = snowpark_fn.lit(None)
@@ -528,11 +529,17 @@ def map_unresolved_function(
                 ):
                     p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
                     p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
-
-
-
-
-
+                    (
+                        result_type,
+                        overflow_possible,
+                    ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x * y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
                     )
                 case (NullType(), NullType()):
                     result_type = DoubleType()
@@ -617,7 +624,7 @@ def map_unresolved_function(
             )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (NullType(), _) | (_, NullType()):
-                    result_type = _get_add_sub_result_type(
+                    result_type, _ = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
@@ -693,14 +700,21 @@ def map_unresolved_function(
                         f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
                 case _:
-                    result_type = _get_add_sub_result_type(
+                    result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
                     )
-
-
-
+
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x + y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
+                    )
+
         case "-":
             spark_function_name = _get_spark_function_name(
                 snowpark_typed_args[0],
@@ -715,7 +729,7 @@ def map_unresolved_function(
                     result_type = LongType()
                     result_exp = snowpark_fn.lit(None).cast(result_type)
                 case (NullType(), _) | (_, NullType()):
-                    result_type = _get_add_sub_result_type(
+                    result_type, _ = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
@@ -806,14 +820,20 @@ def map_unresolved_function(
                         f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                     )
                 case _:
-                    result_type = _get_add_sub_result_type(
+                    result_type, overflow_possible = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
                         snowpark_typed_args[1].typ,
                         spark_function_name,
                     )
-                    result_exp =
-
-
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x - y,
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
+                        result_type,
+                    )
+
         case "/":
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                 case (DecimalType() as t1, NullType()):
@@ -825,15 +845,17 @@ def map_unresolved_function(
                 ):
                     p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
                     p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
-                    result_type,
+                    result_type, overflow_possible = _get_decimal_division_result_type(
                         p1, s1, p2, s2
                     )
-
+
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: _divnull(x, y),
+                        overflow_possible,
+                        global_config.spark_sql_ansi_enabled,
                         result_type,
-                        t,
-                        overflow_detected,
-                        snowpark_args,
-                        spark_function_name,
                     )
                 case (NullType(), NullType()):
                     result_type = DoubleType()
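The "/" branch now computes the result type and an overflow flag first, then delegates the expression to `_arithmetic_operation`, with ANSI mode deciding whether overflow raises or returns NULL. A worked sketch of the division precision/scale rule, restating the formula from `_get_decimal_division_result_type` further down in this diff; the operand types DECIMAL(38,10) and DECIMAL(10,2) are made up for illustration:

    # DECIMAL(p1, s1) / DECIMAL(p2, s2), Spark-style derivation (illustrative operands).
    p1, s1 = 38, 10
    p2, s2 = 10, 2

    result_scale = max(6, s1 + p2 + 1)              # 21
    result_precision = p1 - s1 + s2 + result_scale  # 51
    overflow_possible = result_precision > 38       # True -> runtime overflow check is needed
    if result_precision > 38:
        result_scale = max(6, result_scale - (result_precision - 38))  # 8
        result_precision = 38

    print(result_precision, result_scale, overflow_possible)  # 38 8 True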
@@ -7580,63 +7602,20 @@ def map_unresolved_function(
                 )
                 | (DecimalType(), DecimalType())
             ):
-
-
-
-
-
-                        snowpark_typed_args[0].typ.scale,
-                        snowpark_typed_args[0].typ.precision,
-                    )
-                    s2, p2 = (
-                        snowpark_typed_args[1].typ.scale,
-                        snowpark_typed_args[1].typ.precision,
-                    )
-                    # The scale and precision formula that Spark follows for DecimalType
-                    # arithmetic operations can be found in the following Spark source
-                    # code file:
-                    # https://github.com/apache/spark/blob/a584cc48ef63fefb2e035349c8684250f8b936c4/docs/sql-ref-ansi-compliance.md
-                    new_scale = max(6, s1 + p2 + 1)
-                    new_precision = p1 - s1 + s2 + new_scale
-
-                elif isinstance(snowpark_typed_args[0].typ, DecimalType):
-                    s1, p1 = (
-                        snowpark_typed_args[0].typ.scale,
-                        snowpark_typed_args[0].typ.precision,
-                    )
-                    # INT is treated as Decimal(10, 0)
-                    new_scale = max(6, s1 + 11)
-                    new_precision = p1 - s1 + new_scale
-
-                else:  # right is DecimalType
-                    s2, p2 = (
-                        snowpark_typed_args[1].typ.scale,
-                        snowpark_typed_args[1].typ.precision,
-                    )
-                    # INT is treated as Decimal(10, 0)
-                    new_scale = max(6, 11 + p2)
-                    new_precision = (
-                        10 - 0 + s2 + new_scale
-                    )  # INT has precision 10, scale 0
-
-                # apply precision cap
-                if new_precision > 38:
-                    new_scale -= new_precision - 38
-                    new_precision = 38
-                    new_scale = max(new_scale, 6)
-
-                left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                quotient = snowpark_fn.when(
-                    snowpark_args[1] == 0, snowpark_fn.lit(None)
-                ).otherwise(left_double / right_double)
-                quotient = snowpark_fn.cast(quotient, StringType())
+                p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                result_type, overflow_possible = _get_decimal_division_result_type(
+                    p1, s1, p2, s2
+                )

-                result_exp =
-
+                result_exp = _arithmetic_operation(
+                    snowpark_typed_args[0],
+                    snowpark_typed_args[1],
+                    lambda x, y: _divnull(x, y),
+                    overflow_possible,
+                    False,
+                    result_type,
                 )
-                result_type = DecimalType(new_precision, new_scale)
             case (_NumericType(), _NumericType()):
                 result_exp = snowpark_fn.when(
                     snowpark_args[1] == 0, snowpark_fn.lit(None)
@@ -7749,42 +7728,21 @@ def map_unresolved_function(
                 )
                 | (DecimalType(), DecimalType())
             ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    new_precision = snowpark_typed_args[0].typ.precision + 11
-                    new_scale = snowpark_typed_args[0].typ.scale
-                else:
-                    new_precision = snowpark_typed_args[1].typ.precision + 11
-                    new_scale = snowpark_typed_args[1].typ.scale
-
-                # truncating down appropriately
-                if new_precision > 38:
-                    new_precision = 38
-                if new_scale > new_precision:
-                    new_scale = new_precision
-
-                left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                product = left_double * right_double
-
-                product = snowpark_fn.cast(product, StringType())
-                result_exp = _try_cast_helper(
-                    product, DecimalType(new_precision, new_scale)
+                p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                (
+                    result_type,
+                    overflow_possible,
+                ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+
+                result_exp = _arithmetic_operation(
+                    snowpark_typed_args[0],
+                    snowpark_typed_args[1],
+                    lambda x, y: x * y,
+                    overflow_possible,
+                    False,
+                    result_type,
                 )
-                result_type = DecimalType(new_precision, new_scale)
             case (_NumericType(), _NumericType()):
                 result_exp = snowpark_args[0] * snowpark_args[1]
                 result_exp = _type_with_typer(result_exp)
@@ -8391,20 +8349,10 @@ def map_unresolved_function(
     return spark_col_names, typed_col


-def _cast_helper(column: Column, to: DataType) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        column_mediator = (
-            snowpark_fn.cast(column, StringType())
-            if isinstance(to, DecimalType)
-            else column
-        )
-        return snowpark_fn.cast(column_mediator, to)
-    else:
-        return _try_cast_helper(column, to)
-
-
 def _try_cast_helper(column: Column, to: DataType) -> Column:
     """
+    DEPRECATED because of performance issues
+
     Attempts to cast a given column to a specified data type using the same behaviour as Spark.

     Args:
@@ -9600,71 +9548,109 @@ def _decimal_add_sub_result_type_helper(p1, s1, p2, s2):
     return result_precision, min_scale, return_type_precision, return_type_scale


-def
-    result_type: DecimalType | DataType,
-    other_type: DataType,
-    snowpark_args: list[Column],
-) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        result_exp = snowpark_args[0] * snowpark_args[1]
-    else:
-        if isinstance(other_type, _IntegralType):
-            result_exp = snowpark_args[0].cast(result_type) * snowpark_args[1].cast(
-                result_type
-            )
-        else:
-            result_exp = snowpark_args[0].cast(DoubleType()) * snowpark_args[1].cast(
-                DoubleType()
-            )
-        result_exp = _try_cast_helper(result_exp, result_type)
-    return result_exp
-
-
-def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> DecimalType:
+def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
     result_precision = p1 + p2 + 1
     result_scale = s1 + s2
+    overflow_possible = False
     if result_precision > 38:
+        overflow_possible = True
         if result_scale > 6:
             overflow = result_precision - 38
             result_scale = max(6, result_scale - overflow)
         result_precision = 38
-    return DecimalType(result_precision, result_scale)
+    return DecimalType(result_precision, result_scale), overflow_possible


-def
-
-
-
-
-
+def _arithmetic_operation(
+    arg1: TypedColumn,
+    arg2: TypedColumn,
+    op: Callable[[Column, Column], Column],
+    overflow_possible: bool,
+    should_raise_on_overflow: bool,
+    target_type: DecimalType,
 ) -> Column:
-
-
-
-
-
-
-
+    def _cast_arg(tc: TypedColumn) -> Column:
+        _, s = _get_type_precision(tc.typ)
+        typ = (
+            DoubleType()
+            if s > 0
+            or (
+                isinstance(tc.typ, _FractionalType)
+                and not isinstance(tc.typ, DecimalType)
+            )
+            else LongType()
+        )
+        return tc.col.cast(typ)
+
+    op_for_overflow_check = op(arg1.col.cast(DoubleType()), arg2.col.cast(DoubleType()))
+    safe_op = op(_cast_arg(arg1), _cast_arg(arg2))
+
+    if overflow_possible:
+        return _cast_arithmetic_operation_result(
+            op_for_overflow_check, safe_op, target_type, should_raise_on_overflow
         )
     else:
-
-
-
-
-
+        return op(arg1.col, arg2.col).cast(target_type)
+
+
+def _cast_arithmetic_operation_result(
+    overflow_check_expr: Column,
+    result_expr: Column,
+    target_type: DecimalType,
+    should_raise_on_overflow: bool,
+) -> Column:
+    """
+    Casts an arithmetic operation result to the target decimal type with overflow detection.
+    This function uses a dual-expression approach for robust overflow handling:
+    Args:
+        overflow_check_expr: Arithmetic expression using DoubleType operands for overflow detection.
+            This expression is used ONLY for boundary checking against the target
+            decimal's min/max values. DoubleType preserves the magnitude of large
+            intermediate results that might overflow in decimal arithmetic.
+        result_expr: Arithmetic expression using safer operand types (LongType for integers,
+            DoubleType for fractionals) for the actual result computation.
+        target_type: Target DecimalType to cast the result to.
+        should_raise_on_overflow: If True raises ArithmeticException on overflow, if False, returns NULL on overflow.
+    """
+
+    def create_overflow_handler(min_val, max_val, type_name: str):
+        if should_raise_on_overflow:
+            raise_error = _raise_error_helper(target_type, ArithmeticException)
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                raise_error(
+                    snowpark_fn.lit(
+                        f'[NUMERIC_VALUE_OUT_OF_RANGE] Value cannot be represented as {type_name}. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+                    )
+                ),
+            ).otherwise(result_expr.cast(target_type))
+        else:
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                snowpark_fn.lit(None),
+            ).otherwise(result_expr.cast(target_type))
+
+    precision = target_type.precision
+    scale = target_type.scale
+
+    max_val = (10**precision - 1) / (10**scale)
+    min_val = -max_val
+
+    return create_overflow_handler(min_val, max_val, f"DECIMAL({precision},{scale})")


 def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
-
+    overflow_possible = False
     result_scale = max(6, s1 + p2 + 1)
     result_precision = p1 - s1 + s2 + result_scale
     if result_precision > 38:
-
-        overflow_detected = True
+        overflow_possible = True
         overflow = result_precision - 38
         result_scale = max(6, result_scale - overflow)
         result_precision = 38
-    return DecimalType(result_precision, result_scale),
+    return DecimalType(result_precision, result_scale), overflow_possible


 def _try_arithmetic_helper(
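Taken together, `_get_decimal_multiplication_result_type` now supplies both the target type and an overflow flag, and `_cast_arithmetic_operation_result` turns that flag into a bounds check on a DoubleType probe expression, either raising or returning NULL. A worked sketch that restates those two pieces with made-up operand types DECIMAL(20,5) and DECIMAL(25,10):

    # 1) Result type for DECIMAL(20,5) * DECIMAL(25,10), per the formula above.
    p1, s1, p2, s2 = 20, 5, 25, 10
    result_precision = p1 + p2 + 1               # 46
    result_scale = s1 + s2                       # 15
    overflow_possible = result_precision > 38    # True
    if result_precision > 38:
        if result_scale > 6:
            result_scale = max(6, result_scale - (result_precision - 38))  # 7
        result_precision = 38

    # 2) Bounds used by the overflow check for the resulting DECIMAL(38,7).
    max_val = (10**result_precision - 1) / (10**result_scale)  # ~1.0e31
    min_val = -max_val
    # A probe value outside [min_val, max_val] becomes NULL (or raises under ANSI mode).
    print(result_precision, result_scale, overflow_possible, max_val)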
@@ -9778,46 +9764,20 @@ def _try_arithmetic_helper(
         DecimalType(),
         DecimalType(),
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            # Both decimal types
-            if operation_type == 1 and s1 == s2:  # subtraction with matching scales
-                new_scale = s1
-                max_integral_digits = max(p1 - s1, p2 - s2)
-                new_precision = max_integral_digits + new_scale
-            else:
-                new_scale = max(s1, s2)
-                max_integral_digits = max(p1 - s1, p2 - s2)
-                new_precision = max_integral_digits + new_scale + 1
-
-        # Overflow check
-        if new_precision > 38:
-            if global_config.spark_sql_ansi_enabled:
-                raise ArithmeticException(
-                    f'[NUMERIC_VALUE_OUT_OF_RANGE] Precision {new_precision} exceeds maximum allowed precision of 38. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
-                )
-            return snowpark_fn.lit(None)
-
-        left_operand, right_operand = snowpark_args[0], snowpark_args[1]
-
-        result = (
-            left_operand + right_operand
-            if operation_type == 0
-            else left_operand - right_operand
+        result_type, overflow_possible = _get_add_sub_result_type(
+            typed_args[0].typ,
+            typed_args[1].typ,
+            "try_add" if operation_type == 0 else "try_subtract",
+        )
+
+        return _arithmetic_operation(
+            typed_args[0],
+            typed_args[1],
+            lambda x, y: x + y if operation_type == 0 else x - y,
+            overflow_possible,
+            False,
+            result_type,
         )
-        return snowpark_fn.cast(result, DecimalType(new_precision, new_scale))

     # If either of the inputs is floating point, we can just let it go through to Snowflake, where overflow
     # matches Spark and goes to inf.
@@ -9863,7 +9823,8 @@ def _get_add_sub_result_type(
     type1: DataType,
     type2: DataType,
     spark_function_name: str,
-) -> DataType:
+) -> tuple[DataType, bool]:
+    overflow_possible = False
     result_type = _find_common_type([type1, type2])
     match result_type:
         case DecimalType():
@@ -9872,6 +9833,7 @@ def _get_add_sub_result_type(
             result_scale = max(s1, s2)
             result_precision = max(p1 - s1, p2 - s2) + result_scale + 1
             if result_precision > 38:
+                overflow_possible = True
                 if result_scale > 6:
                     overflow = result_precision - 38
                     result_scale = max(6, result_scale - overflow)
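A worked sketch of the add/subtract rule from the context lines above, using made-up operand types DECIMAL(38,10) on both sides; the final cap to precision 38 mirrors the multiplication helper shown earlier in this diff:

    p1, s1 = 38, 10
    p2, s2 = 38, 10

    result_scale = max(s1, s2)                                    # 10
    result_precision = max(p1 - s1, p2 - s2) + result_scale + 1   # 39
    overflow_possible = result_precision > 38                     # True
    if result_precision > 38:
        if result_scale > 6:
            result_scale = max(6, result_scale - (result_precision - 38))  # 9
        result_precision = 38

    print(result_precision, result_scale, overflow_possible)  # 38 9 True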
@@ -9900,7 +9862,7 @@ def _get_add_sub_result_type(
             raise AnalysisException(
                 f'[DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "BOOLEAN".',
             )
-    return result_type
+    return result_type, overflow_possible


 def _get_spark_function_name(
@@ -7,8 +7,27 @@ from urllib.parse import urlparse
 CLOUD_PREFIX_TO_CLOUD = {
     "abfss": "azure",
     "wasbs": "azure",
+    "gcs": "gcp",
+    "gs": "gcp",
 }

+SUPPORTED_COMPRESSION_PER_FORMAT = {
+    "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+    "text": {"NONE"},
+}
+
+
+def supported_compressions_for_format(format: str) -> set[str]:
+    return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+def is_supported_compression(format: str, compression: str | None) -> bool:
+    if compression is None:
+        return True
+    return compression in supported_compressions_for_format(format)
+

 def get_cloud_from_url(
     url: str,
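This hunk appears to be from snowflake/snowpark_connect/relation/io_utils.py (per the file list above). A quick self-contained sketch of how the new compression helpers behave; the dictionary and functions are copied from the hunk, and the sample calls are illustrative:

    SUPPORTED_COMPRESSION_PER_FORMAT = {
        "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
        "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
        "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
        "text": {"NONE"},
    }

    def supported_compressions_for_format(format: str) -> set[str]:
        return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())

    def is_supported_compression(format: str, compression: str | None) -> bool:
        if compression is None:
            return True
        return compression in supported_compressions_for_format(format)

    assert is_supported_compression("parquet", "SNAPPY")
    assert is_supported_compression("csv", None)           # no compression requested is always allowed
    assert not is_supported_compression("text", "GZIP")    # text only supports NONE
    assert not is_supported_compression("avro", "SNAPPY")  # unknown format -> empty set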
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
         or path.startswith("azure://")
         or path.startswith("abfss://")
         or path.startswith("wasbs://")  # Azure
-        or path.startswith("gcs://")
+        or path.startswith("gcs://")
+        or path.startswith("gs://")  # GCP
     )
