snowpark-connect 0.28.1__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (28)
  1. snowflake/snowpark_connect/config.py +11 -2
  2. snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
  3. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  4. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  5. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  6. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  7. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  8. snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
  9. snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
  10. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  11. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  12. snowflake/snowpark_connect/relation/write/map_write.py +38 -4
  13. snowflake/snowpark_connect/server.py +18 -13
  14. snowflake/snowpark_connect/utils/context.py +0 -14
  15. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  16. snowflake/snowpark_connect/utils/session.py +3 -0
  17. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  18. snowflake/snowpark_connect/version.py +1 -1
  19. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
  20. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +28 -28
  21. {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
  22. {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
  23. {snowpark_connect-0.28.1.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
  24. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
  25. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
  26. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
  27. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
  28. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py
@@ -264,16 +264,22 @@ SESSION_CONFIG_KEY_WHITELIST = {
     "snowpark.connect.udtf.compatibility_mode",
     "snowpark.connect.views.duplicate_column_names_handling_mode",
     "enable_snowflake_extension_behavior",
+    "spark.hadoop.fs.s3a.server-side-encryption.key",
+    "spark.hadoop.fs.s3a.assumed.role.arn",
 }
-AZURE_SAS_KEY = re.compile(
+AZURE_ACCOUNT_KEY = re.compile(
     r"^fs\.azure\.sas\.[^\.]+\.[^\.]+\.blob\.core\.windows\.net$"
 )
+AZURE_SAS_KEY = re.compile(
+    r"^fs\.azure\.sas\.fixed\.token\.[^\.]+\.dfs\.core\.windows\.net$"
+)


 def valid_session_config_key(key: str):
     return (
         key in SESSION_CONFIG_KEY_WHITELIST  # AWS session keys
         or AZURE_SAS_KEY.match(key)  # Azure session keys
+        or AZURE_ACCOUNT_KEY.match(key)  # Azure account keys
     )

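For illustration, a minimal standalone sketch of which Azure session keys each pattern now accepts. The regexes are copied from the hunk above; the sample key names are hypothetical placeholders, and this is not an import of the actual config module:

import re

AZURE_ACCOUNT_KEY = re.compile(
    r"^fs\.azure\.sas\.[^\.]+\.[^\.]+\.blob\.core\.windows\.net$"
)
AZURE_SAS_KEY = re.compile(
    r"^fs\.azure\.sas\.fixed\.token\.[^\.]+\.dfs\.core\.windows\.net$"
)

# The blob-endpoint pattern (now named AZURE_ACCOUNT_KEY):
assert AZURE_ACCOUNT_KEY.match("fs.azure.sas.mycontainer.myaccount.blob.core.windows.net")
# The new dfs-endpoint fixed-token pattern (AZURE_SAS_KEY):
assert AZURE_SAS_KEY.match("fs.azure.sas.fixed.token.myaccount.dfs.core.windows.net")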
@@ -578,7 +584,10 @@ def set_snowflake_parameters(
             cte_enabled = str_to_bool(value)
             snowpark_session.cte_optimization_enabled = cte_enabled
             logger.info(f"Updated snowpark session CTE optimization: {cte_enabled}")
-
+        case "snowpark.connect.structured_types.fix":
+            # TODO: SNOW-2367714 Remove this once the fix is automatically enabled in Snowpark
+            snowpark.context._enable_fix_2360274 = str_to_bool(value)
+            logger.info(f"Updated snowpark session structured types fix: {value}")
         case _:
             pass

snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -15,6 +15,7 @@ import tempfile
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import suppress
 from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Context, Decimal
 from functools import partial, reduce
@@ -199,7 +200,7 @@ def _validate_numeric_args(
             case StringType():
                 # Cast strings to doubles following Spark
                 # https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L204
-                modified_args[i] = _try_cast_helper(snowpark_args[i], DoubleType())
+                modified_args[i] = snowpark_fn.try_cast(snowpark_args[i], DoubleType())
             case _:
                 raise TypeError(
                     f"Data type mismatch: {function_name} requires numeric types, but got {typed_args[0].typ} and {typed_args[1].typ}."
@@ -519,7 +520,7 @@ def map_unresolved_function(
                         DecimalType() as t,
                     ):
                         p1, s1 = _get_type_precision(t)
-                        result_type = _get_decimal_multiplication_result_type(
+                        result_type, _ = _get_decimal_multiplication_result_type(
                             p1, s1, p1, s1
                         )
                         result_exp = snowpark_fn.lit(None)
@@ -528,11 +529,17 @@ def map_unresolved_function(
                     ):
                         p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
                         p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
-                        result_type = _get_decimal_multiplication_result_type(
-                            p1, s1, p2, s2
-                        )
-                        result_exp = _get_decimal_multiplication_result_exp(
-                            result_type, t, snowpark_args
+                        (
+                            result_type,
+                            overflow_possible,
+                        ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+                        result_exp = _arithmetic_operation(
+                            snowpark_typed_args[0],
+                            snowpark_typed_args[1],
+                            lambda x, y: x * y,
+                            overflow_possible,
+                            global_config.spark_sql_ansi_enabled,
+                            result_type,
                         )
                     case (NullType(), NullType()):
                         result_type = DoubleType()
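For reference, a standalone re-derivation (not an import of the private helper) of the multiplication rule that _get_decimal_multiplication_result_type applies, with two worked inputs:

# Decimal multiplication rule: precision = p1 + p2 + 1, scale = s1 + s2,
# capped at precision 38 (scale floored at 6), flagging when overflow becomes possible.
def multiplication_result_type(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int, bool]:
    precision, scale = p1 + p2 + 1, s1 + s2
    overflow_possible = precision > 38
    if overflow_possible:
        if scale > 6:
            scale = max(6, scale - (precision - 38))
        precision = 38
    return precision, scale, overflow_possible

assert multiplication_result_type(10, 2, 10, 2) == (21, 4, False)   # DECIMAL(10,2) * DECIMAL(10,2)
assert multiplication_result_type(38, 10, 38, 10) == (38, 6, True)  # capped, overflow possible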
@@ -617,7 +624,7 @@ def map_unresolved_function(
                 )
                 match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                     case (NullType(), _) | (_, NullType()):
-                        result_type = _get_add_sub_result_type(
+                        result_type, _ = _get_add_sub_result_type(
                             snowpark_typed_args[0].typ,
                             snowpark_typed_args[1].typ,
                             spark_function_name,
@@ -693,14 +700,21 @@ def map_unresolved_function(
                             f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                         )
                     case _:
-                        result_type = _get_add_sub_result_type(
+                        result_type, overflow_possible = _get_add_sub_result_type(
                             snowpark_typed_args[0].typ,
                             snowpark_typed_args[1].typ,
                             spark_function_name,
                         )
-                        result_exp = snowpark_args[0] + snowpark_args[1]
-                        if isinstance(result_type, DecimalType):
-                            result_exp = _cast_helper(result_exp, result_type)
+
+                        result_exp = _arithmetic_operation(
+                            snowpark_typed_args[0],
+                            snowpark_typed_args[1],
+                            lambda x, y: x + y,
+                            overflow_possible,
+                            global_config.spark_sql_ansi_enabled,
+                            result_type,
+                        )
+
             case "-":
                 spark_function_name = _get_spark_function_name(
                     snowpark_typed_args[0],
@@ -715,7 +729,7 @@ def map_unresolved_function(
                         result_type = LongType()
                         result_exp = snowpark_fn.lit(None).cast(result_type)
                     case (NullType(), _) | (_, NullType()):
-                        result_type = _get_add_sub_result_type(
+                        result_type, _ = _get_add_sub_result_type(
                             snowpark_typed_args[0].typ,
                             snowpark_typed_args[1].typ,
                             spark_function_name,
@@ -806,14 +820,20 @@ def map_unresolved_function(
                             f'[DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES] Cannot resolve "{spark_function_name}" due to data type mismatch: the left and right operands of the binary operator have incompatible types ("{snowpark_typed_args[0].typ}" and "{snowpark_typed_args[1].typ}").'
                         )
                     case _:
-                        result_type = _get_add_sub_result_type(
+                        result_type, overflow_possible = _get_add_sub_result_type(
                             snowpark_typed_args[0].typ,
                             snowpark_typed_args[1].typ,
                             spark_function_name,
                         )
-                        result_exp = snowpark_args[0] - snowpark_args[1]
-                        if isinstance(result_type, DecimalType):
-                            result_exp = _cast_helper(result_exp, result_type)
+                        result_exp = _arithmetic_operation(
+                            snowpark_typed_args[0],
+                            snowpark_typed_args[1],
+                            lambda x, y: x - y,
+                            overflow_possible,
+                            global_config.spark_sql_ansi_enabled,
+                            result_type,
+                        )
+
             case "/":
                 match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
                     case (DecimalType() as t1, NullType()):
  match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
819
839
  case (DecimalType() as t1, NullType()):
@@ -825,15 +845,17 @@ def map_unresolved_function(
825
845
  ):
826
846
  p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
827
847
  p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
828
- result_type, overflow_detected = _get_decimal_division_result_type(
848
+ result_type, overflow_possible = _get_decimal_division_result_type(
829
849
  p1, s1, p2, s2
830
850
  )
831
- result_exp = _get_decimal_division_result_exp(
851
+
852
+ result_exp = _arithmetic_operation(
853
+ snowpark_typed_args[0],
854
+ snowpark_typed_args[1],
855
+ lambda x, y: _divnull(x, y),
856
+ overflow_possible,
857
+ global_config.spark_sql_ansi_enabled,
832
858
  result_type,
833
- t,
834
- overflow_detected,
835
- snowpark_args,
836
- spark_function_name,
837
859
  )
838
860
  case (NullType(), NullType()):
839
861
  result_type = DoubleType()
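Similarly, a standalone re-derivation of the division rule behind _get_decimal_division_result_type (scale = max(6, s1 + p2 + 1), precision = p1 - s1 + s2 + scale, capped at 38), again only as an illustration of the formula visible in the diff:

def division_result_type(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int, bool]:
    scale = max(6, s1 + p2 + 1)
    precision = p1 - s1 + s2 + scale
    overflow_possible = precision > 38
    if overflow_possible:
        scale = max(6, scale - (precision - 38))
        precision = 38
    return precision, scale, overflow_possible

assert division_result_type(10, 2, 10, 2) == (23, 13, False)  # DECIMAL(10,2) / DECIMAL(10,2)
assert division_result_type(38, 6, 38, 6) == (38, 6, True)    # capped at 38, overflow possible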
@@ -7580,63 +7602,20 @@ def map_unresolved_function(
                     )
                     | (DecimalType(), DecimalType())
                 ):
-                    # compute new precision and scale using correct decimal division rules
-                    if isinstance(
-                        snowpark_typed_args[0].typ, DecimalType
-                    ) and isinstance(snowpark_typed_args[1].typ, DecimalType):
-                        s1, p1 = (
-                            snowpark_typed_args[0].typ.scale,
-                            snowpark_typed_args[0].typ.precision,
-                        )
-                        s2, p2 = (
-                            snowpark_typed_args[1].typ.scale,
-                            snowpark_typed_args[1].typ.precision,
-                        )
-                        # The scale and precision formula that Spark follows for DecimalType
-                        # arithmetic operations can be found in the following Spark source
-                        # code file:
-                        # https://github.com/apache/spark/blob/a584cc48ef63fefb2e035349c8684250f8b936c4/docs/sql-ref-ansi-compliance.md
-                        new_scale = max(6, s1 + p2 + 1)
-                        new_precision = p1 - s1 + s2 + new_scale
-
-                    elif isinstance(snowpark_typed_args[0].typ, DecimalType):
-                        s1, p1 = (
-                            snowpark_typed_args[0].typ.scale,
-                            snowpark_typed_args[0].typ.precision,
-                        )
-                        # INT is treated as Decimal(10, 0)
-                        new_scale = max(6, s1 + 11)
-                        new_precision = p1 - s1 + new_scale
-
-                    else:  # right is DecimalType
-                        s2, p2 = (
-                            snowpark_typed_args[1].typ.scale,
-                            snowpark_typed_args[1].typ.precision,
-                        )
-                        # INT is treated as Decimal(10, 0)
-                        new_scale = max(6, 11 + p2)
-                        new_precision = (
-                            10 - 0 + s2 + new_scale
-                        )  # INT has precision 10, scale 0
-
-                    # apply precision cap
-                    if new_precision > 38:
-                        new_scale -= new_precision - 38
-                        new_precision = 38
-                        new_scale = max(new_scale, 6)
-
-                    left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                    right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                    quotient = snowpark_fn.when(
-                        snowpark_args[1] == 0, snowpark_fn.lit(None)
-                    ).otherwise(left_double / right_double)
-                    quotient = snowpark_fn.cast(quotient, StringType())
+                    p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                    p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                    result_type, overflow_possible = _get_decimal_division_result_type(
+                        p1, s1, p2, s2
+                    )

-                    result_exp = _try_cast_helper(
-                        quotient, DecimalType(new_precision, new_scale)
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: _divnull(x, y),
+                        overflow_possible,
+                        False,
+                        result_type,
                     )
-                    result_type = DecimalType(new_precision, new_scale)
                 case (_NumericType(), _NumericType()):
                     result_exp = snowpark_fn.when(
                         snowpark_args[1] == 0, snowpark_fn.lit(None)
@@ -7749,42 +7728,21 @@ def map_unresolved_function(
                     )
                     | (DecimalType(), DecimalType())
                 ):
-                    # figure out what precision to use as the overflow amount
-                    if isinstance(
-                        snowpark_typed_args[0].typ, DecimalType
-                    ) and isinstance(snowpark_typed_args[1].typ, DecimalType):
-                        new_precision = (
-                            snowpark_typed_args[0].typ.precision
-                            + snowpark_typed_args[1].typ.precision
-                            + 1
-                        )
-                        new_scale = (
-                            snowpark_typed_args[0].typ.scale
-                            + snowpark_typed_args[1].typ.scale
-                        )
-                    elif isinstance(snowpark_typed_args[0].typ, DecimalType):
-                        new_precision = snowpark_typed_args[0].typ.precision + 11
-                        new_scale = snowpark_typed_args[0].typ.scale
-                    else:
-                        new_precision = snowpark_typed_args[1].typ.precision + 11
-                        new_scale = snowpark_typed_args[1].typ.scale
-
-                    # truncating down appropriately
-                    if new_precision > 38:
-                        new_precision = 38
-                    if new_scale > new_precision:
-                        new_scale = new_precision
-
-                    left_double = snowpark_fn.cast(snowpark_args[0], DoubleType())
-                    right_double = snowpark_fn.cast(snowpark_args[1], DoubleType())
-
-                    product = left_double * right_double
-
-                    product = snowpark_fn.cast(product, StringType())
-                    result_exp = _try_cast_helper(
-                        product, DecimalType(new_precision, new_scale)
+                    p1, s1 = _get_type_precision(snowpark_typed_args[0].typ)
+                    p2, s2 = _get_type_precision(snowpark_typed_args[1].typ)
+                    (
+                        result_type,
+                        overflow_possible,
+                    ) = _get_decimal_multiplication_result_type(p1, s1, p2, s2)
+
+                    result_exp = _arithmetic_operation(
+                        snowpark_typed_args[0],
+                        snowpark_typed_args[1],
+                        lambda x, y: x * y,
+                        overflow_possible,
+                        False,
+                        result_type,
                     )
-                    result_type = DecimalType(new_precision, new_scale)
                 case (_NumericType(), _NumericType()):
                     result_exp = snowpark_args[0] * snowpark_args[1]
                     result_exp = _type_with_typer(result_exp)
@@ -8391,20 +8349,10 @@ def map_unresolved_function(
     return spark_col_names, typed_col


-def _cast_helper(column: Column, to: DataType) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        column_mediator = (
-            snowpark_fn.cast(column, StringType())
-            if isinstance(to, DecimalType)
-            else column
-        )
-        return snowpark_fn.cast(column_mediator, to)
-    else:
-        return _try_cast_helper(column, to)
-
-
 def _try_cast_helper(column: Column, to: DataType) -> Column:
     """
+    DEPRECATED because of performance issues
+
     Attempts to cast a given column to a specified data type using the same behaviour as Spark.

     Args:
@@ -9600,71 +9548,109 @@ def _decimal_add_sub_result_type_helper(p1, s1, p2, s2):
     return result_precision, min_scale, return_type_precision, return_type_scale


-def _get_decimal_multiplication_result_exp(
-    result_type: DecimalType | DataType,
-    other_type: DataType,
-    snowpark_args: list[Column],
-) -> Column:
-    if global_config.spark_sql_ansi_enabled:
-        result_exp = snowpark_args[0] * snowpark_args[1]
-    else:
-        if isinstance(other_type, _IntegralType):
-            result_exp = snowpark_args[0].cast(result_type) * snowpark_args[1].cast(
-                result_type
-            )
-        else:
-            result_exp = snowpark_args[0].cast(DoubleType()) * snowpark_args[1].cast(
-                DoubleType()
-            )
-        result_exp = _try_cast_helper(result_exp, result_type)
-    return result_exp
-
-
-def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> DecimalType:
+def _get_decimal_multiplication_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
     result_precision = p1 + p2 + 1
     result_scale = s1 + s2
+    overflow_possible = False
     if result_precision > 38:
+        overflow_possible = True
         if result_scale > 6:
             overflow = result_precision - 38
             result_scale = max(6, result_scale - overflow)
         result_precision = 38
-    return DecimalType(result_precision, result_scale)
+    return DecimalType(result_precision, result_scale), overflow_possible


-def _get_decimal_division_result_exp(
-    result_type: DecimalType | DataType,
-    other_type: DataType,
-    overflow_detected: bool,
-    snowpark_args: list[Column],
-    spark_function_name: str,
+def _arithmetic_operation(
+    arg1: TypedColumn,
+    arg2: TypedColumn,
+    op: Callable[[Column, Column], Column],
+    overflow_possible: bool,
+    should_raise_on_overflow: bool,
+    target_type: DecimalType,
 ) -> Column:
-    if (
-        isinstance(other_type, DecimalType)
-        and overflow_detected
-        and global_config.spark_sql_ansi_enabled
-    ):
-        raise ArithmeticException(
-            f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+    def _cast_arg(tc: TypedColumn) -> Column:
+        _, s = _get_type_precision(tc.typ)
+        typ = (
+            DoubleType()
+            if s > 0
+            or (
+                isinstance(tc.typ, _FractionalType)
+                and not isinstance(tc.typ, DecimalType)
+            )
+            else LongType()
+        )
+        return tc.col.cast(typ)
+
+    op_for_overflow_check = op(arg1.col.cast(DoubleType()), arg2.col.cast(DoubleType()))
+    safe_op = op(_cast_arg(arg1), _cast_arg(arg2))
+
+    if overflow_possible:
+        return _cast_arithmetic_operation_result(
+            op_for_overflow_check, safe_op, target_type, should_raise_on_overflow
         )
     else:
-        dividend = snowpark_args[0].cast(DoubleType())
-        divisor = snowpark_args[1]
-        result_exp = _divnull(dividend, divisor)
-        result_exp = _cast_helper(result_exp, result_type)
-        return result_exp
+        return op(arg1.col, arg2.col).cast(target_type)
+
+
+def _cast_arithmetic_operation_result(
+    overflow_check_expr: Column,
+    result_expr: Column,
+    target_type: DecimalType,
+    should_raise_on_overflow: bool,
+) -> Column:
+    """
+    Casts an arithmetic operation result to the target decimal type with overflow detection.
+    This function uses a dual-expression approach for robust overflow handling:
+    Args:
+        overflow_check_expr: Arithmetic expression using DoubleType operands for overflow detection.
+                             This expression is used ONLY for boundary checking against the target
+                             decimal's min/max values. DoubleType preserves the magnitude of large
+                             intermediate results that might overflow in decimal arithmetic.
+        result_expr: Arithmetic expression using safer operand types (LongType for integers,
+                     DoubleType for fractionals) for the actual result computation.
+        target_type: Target DecimalType to cast the result to.
+        should_raise_on_overflow: If True raises ArithmeticException on overflow, if False, returns NULL on overflow.
+    """
+
+    def create_overflow_handler(min_val, max_val, type_name: str):
+        if should_raise_on_overflow:
+            raise_error = _raise_error_helper(target_type, ArithmeticException)
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                raise_error(
+                    snowpark_fn.lit(
+                        f'[NUMERIC_VALUE_OUT_OF_RANGE] Value cannot be represented as {type_name}. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+                    )
+                ),
+            ).otherwise(result_expr.cast(target_type))
+        else:
+            return snowpark_fn.when(
+                (overflow_check_expr < snowpark_fn.lit(min_val))
+                | (overflow_check_expr > snowpark_fn.lit(max_val)),
+                snowpark_fn.lit(None),
+            ).otherwise(result_expr.cast(target_type))
+
+    precision = target_type.precision
+    scale = target_type.scale
+
+    max_val = (10**precision - 1) / (10**scale)
+    min_val = -max_val
+
+    return create_overflow_handler(min_val, max_val, f"DECIMAL({precision},{scale})")


 def _get_decimal_division_result_type(p1, s1, p2, s2) -> tuple[DecimalType, bool]:
-    overflow_detected = False
+    overflow_possible = False
     result_scale = max(6, s1 + p2 + 1)
     result_precision = p1 - s1 + s2 + result_scale
     if result_precision > 38:
-        if result_precision > 40:
-            overflow_detected = True
+        overflow_possible = True
         overflow = result_precision - 38
         result_scale = max(6, result_scale - overflow)
         result_precision = 38
-    return DecimalType(result_precision, result_scale), overflow_detected
+    return DecimalType(result_precision, result_scale), overflow_possible


 def _try_arithmetic_helper(
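As a worked illustration of the bounds used by _cast_arithmetic_operation_result, the largest value DECIMAL(p, s) can hold is (10**p - 1) / 10**s; the double-typed check expression is compared against those bounds and either nulled out or, in ANSI mode, turned into an error. A minimal standalone sketch in plain Python (not Snowpark column expressions):

def decimal_bounds(precision: int, scale: int) -> tuple[float, float]:
    # Largest magnitude representable by DECIMAL(precision, scale).
    max_val = (10**precision - 1) / (10**scale)
    return -max_val, max_val

def checked_cast(check_value: float, ansi: bool, precision: int, scale: int):
    min_val, max_val = decimal_bounds(precision, scale)
    if check_value < min_val or check_value > max_val:
        if ansi:
            raise ArithmeticError(f"value out of range for DECIMAL({precision},{scale})")
        return None  # non-ANSI / try_* variants return NULL instead
    return check_value

assert decimal_bounds(5, 2) == (-999.99, 999.99)
assert checked_cast(1000.0, ansi=False, precision=5, scale=2) is None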
@@ -9778,46 +9764,20 @@ def _try_arithmetic_helper(
             DecimalType(),
             DecimalType(),
         ):
-            p1, s1 = _get_type_precision(typed_args[0].typ)
-            p2, s2 = _get_type_precision(typed_args[1].typ)
-
-            if isinstance(typed_args[0].typ, _IntegralType) and isinstance(
-                typed_args[1].typ, DecimalType
-            ):
-                new_scale = s2
-                new_precision = max(p2, p1 + s2)
-            elif isinstance(typed_args[0].typ, DecimalType) and isinstance(
-                typed_args[1].typ, _IntegralType
-            ):
-                new_scale = s1
-                new_precision = max(p1, p2 + s1)
-            else:
-                # Both decimal types
-                if operation_type == 1 and s1 == s2:  # subtraction with matching scales
-                    new_scale = s1
-                    max_integral_digits = max(p1 - s1, p2 - s2)
-                    new_precision = max_integral_digits + new_scale
-                else:
-                    new_scale = max(s1, s2)
-                    max_integral_digits = max(p1 - s1, p2 - s2)
-                    new_precision = max_integral_digits + new_scale + 1
-
-            # Overflow check
-            if new_precision > 38:
-                if global_config.spark_sql_ansi_enabled:
-                    raise ArithmeticException(
-                        f'[NUMERIC_VALUE_OUT_OF_RANGE] Precision {new_precision} exceeds maximum allowed precision of 38. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
-                    )
-                return snowpark_fn.lit(None)
-
-            left_operand, right_operand = snowpark_args[0], snowpark_args[1]
-
-            result = (
-                left_operand + right_operand
-                if operation_type == 0
-                else left_operand - right_operand
+            result_type, overflow_possible = _get_add_sub_result_type(
+                typed_args[0].typ,
+                typed_args[1].typ,
+                "try_add" if operation_type == 0 else "try_subtract",
+            )
+
+            return _arithmetic_operation(
+                typed_args[0],
+                typed_args[1],
+                lambda x, y: x + y if operation_type == 0 else x - y,
+                overflow_possible,
+                False,
+                result_type,
             )
-            return snowpark_fn.cast(result, DecimalType(new_precision, new_scale))

     # If either of the inputs is floating point, we can just let it go through to Snowflake, where overflow
     # matches Spark and goes to inf.
@@ -9863,7 +9823,8 @@ def _get_add_sub_result_type(
     type1: DataType,
     type2: DataType,
     spark_function_name: str,
-) -> DataType:
+) -> tuple[DataType, bool]:
+    overflow_possible = False
     result_type = _find_common_type([type1, type2])
     match result_type:
         case DecimalType():
@@ -9872,6 +9833,7 @@ def _get_add_sub_result_type(
             result_scale = max(s1, s2)
             result_precision = max(p1 - s1, p2 - s2) + result_scale + 1
             if result_precision > 38:
+                overflow_possible = True
                 if result_scale > 6:
                     overflow = result_precision - 38
                     result_scale = max(6, result_scale - overflow)
@@ -9900,7 +9862,7 @@ def _get_add_sub_result_type(
             raise AnalysisException(
                 f'[DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "BOOLEAN".',
             )
-    return result_type
+    return result_type, overflow_possible


 def _get_spark_function_name(
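For completeness, the addition/subtraction rule visible in _get_add_sub_result_type above, re-derived as a standalone sketch; the final cap of the precision at 38 is implied by the surrounding context rather than shown in the hunk, so treat it as an assumption:

def add_sub_result_type(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int, bool]:
    scale = max(s1, s2)
    precision = max(p1 - s1, p2 - s2) + scale + 1
    overflow_possible = precision > 38
    if overflow_possible:
        if scale > 6:
            scale = max(6, scale - (precision - 38))
        precision = 38  # assumed cap, mirroring the multiplication rule
    return precision, scale, overflow_possible

assert add_sub_result_type(10, 2, 12, 4) == (13, 4, False)   # DECIMAL(10,2) + DECIMAL(12,4)
assert add_sub_result_type(38, 10, 38, 2) == (38, 6, True)   # capped, overflow possible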
snowflake/snowpark_connect/relation/io_utils.py
@@ -7,8 +7,27 @@ from urllib.parse import urlparse
 CLOUD_PREFIX_TO_CLOUD = {
     "abfss": "azure",
     "wasbs": "azure",
+    "gcs": "gcp",
+    "gs": "gcp",
 }

+SUPPORTED_COMPRESSION_PER_FORMAT = {
+    "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+    "text": {"NONE"},
+}
+
+
+def supported_compressions_for_format(format: str) -> set[str]:
+    return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+def is_supported_compression(format: str, compression: str | None) -> bool:
+    if compression is None:
+        return True
+    return compression in supported_compressions_for_format(format)
+

 def get_cloud_from_url(
     url: str,
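A brief usage sketch of the new helpers, assuming they are importable from the module path shown in the files-changed list:

from snowflake.snowpark_connect.relation.io_utils import (
    is_supported_compression,
    supported_compressions_for_format,
)

assert is_supported_compression("csv", "GZIP")              # GZIP is allowed for CSV
assert is_supported_compression("json", None)               # no codec requested is always accepted
assert not is_supported_compression("text", "GZIP")         # text files only support NONE
assert supported_compressions_for_format("avro") == set()   # unknown formats report no codecs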
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
         or path.startswith("azure://")
         or path.startswith("abfss://")
         or path.startswith("wasbs://")  # Azure
-        or path.startswith("gcs://")  # GCP
+        or path.startswith("gcs://")
+        or path.startswith("gs://")  # GCP
     )
