snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. snowflake/snowpark_connect/client/server.py +37 -0
  2. snowflake/snowpark_connect/config.py +72 -3
  3. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  4. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  5. snowflake/snowpark_connect/expression/map_cast.py +108 -17
  6. snowflake/snowpark_connect/expression/map_udf.py +1 -0
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
  8. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  11. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  12. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  13. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
  14. snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
  15. snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
  16. snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
  17. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
  18. snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
  20. snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
  21. snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
  22. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
  23. snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
  24. snowflake/snowpark_connect/resources_initializer.py +90 -29
  25. snowflake/snowpark_connect/server.py +6 -41
  26. snowflake/snowpark_connect/server_common/__init__.py +4 -1
  27. snowflake/snowpark_connect/type_support.py +130 -0
  28. snowflake/snowpark_connect/utils/context.py +8 -0
  29. snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
  30. snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
  31. snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
  32. snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
  33. snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
  34. snowflake/snowpark_connect/utils/telemetry.py +33 -22
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  36. snowflake/snowpark_connect/version.py +1 -1
  37. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
  38. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
  39. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
  40. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/snowpark_connect/client/server.py
+++ b/snowflake/snowpark_connect/client/server.py
@@ -58,6 +58,7 @@ from snowflake.snowpark_connect.server_common import ( # noqa: F401 - re-export
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
+from snowflake.snowpark_connect.utils.telemetry import telemetry
 from spark.connect import envelope_pb2
 
 
@@ -194,6 +195,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         """Execute a Spark plan by forwarding to GS backend."""
         logger.debug("Received Execute Plan request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -216,12 +218,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             )
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in ExecutePlan, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in ExecutePlan call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def _call_backend_config(
         self, request: base_pb2.ConfigRequest
@@ -299,6 +305,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         self, request: base_pb2.ConfigRequest, context: grpc.ServicerContext
     ) -> base_pb2.ConfigResponse:
         logger.debug("Received Config request")
+        telemetry.initialize_request_summary(request)
 
         try:
             op = request.operation
@@ -370,18 +377,23 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             return self._call_backend_config(request)
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error("Error in Config", exc_info=True)
             return _log_and_return_error(
                 "Error in Config call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def AnalyzePlan(
         self, request: base_pb2.AnalyzePlanRequest, context: grpc.ServicerContext
     ) -> base_pb2.AnalyzePlanResponse:
         logger.debug("Received Analyze Plan request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -403,12 +415,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             return resp_envelope.analyze_plan_response
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in AnalyzePlan, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in AnalyzePlan call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def AddArtifacts(
         self,
@@ -422,6 +438,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
 
         for request in request_iterator:
             query_id = None
+            telemetry.initialize_request_summary(request)
             try:
                 response_bytes = spark_resource.add_artifacts(
                     request.SerializeToString()
@@ -444,14 +461,18 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
                 add_artifacts_response = resp_envelope.add_artifacts_response
 
             except GrpcErrorStatusException as e:
+                telemetry.report_request_failure(e)
                 context.abort_with_status(rpc_status.to_status(e.status))
             except Exception as e:
+                telemetry.report_request_failure(e)
                 logger.error(
                     f"Error in AddArtifacts, query id {query_id}", exc_info=True
                 )
                 return _log_and_return_error(
                     "Error in AddArtifacts call", e, grpc.StatusCode.INTERNAL, context
                 )
+            finally:
+                telemetry.send_request_summary_telemetry()
 
         if add_artifacts_response is None:
             raise ValueError("AddArtifacts received empty request_iterator")
@@ -464,6 +485,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
         logger.debug("Received ArtifactStatus request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -485,12 +507,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
 
             return resp_envelope.artifact_status_response
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in ArtifactStatus, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in ArtifactStatus call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def Interrupt(
         self, request: base_pb2.InterruptRequest, context: grpc.ServicerContext
@@ -505,16 +531,20 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
     ) -> base_pb2.ReleaseExecuteResponse:
         """Release an execution."""
         logger.debug("Received Release Execute request")
+        telemetry.initialize_request_summary(request)
         try:
             return base_pb2.ReleaseExecuteResponse(
                 session_id=request.session_id,
                 operation_id=request.operation_id or str(uuid.uuid4()),
             )
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error("Error in ReleaseExecute", exc_info=True)
             return _log_and_return_error(
                 "Error in ReleaseExecute call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def ReattachExecute(
         self, request: base_pb2.ReattachExecuteRequest, context: grpc.ServicerContext
@@ -542,6 +572,9 @@ def _serve(
         if session is None:
             session = get_or_create_snowpark_session()
 
+        # Initialize telemetry with session and thin client source identifier
+        telemetry.initialize(session, source="SparkConnectLightWeightClient")
+
         server_options = _get_default_grpc_options()
         max_workers = get_int_from_env("SPARK_CONNECT_CLIENT_GRPC_MAX_WORKERS", 10)
@@ -560,6 +593,7 @@ def _serve(
         server.start()
         server_running.set()
         logger.info("Snowpark Connect server started!")
+        telemetry.send_server_started_telemetry()
 
         if stop_event is not None:
             # start a background thread to listen for stop event and terminate the server
@@ -579,6 +613,9 @@ def _serve(
         logger.error("Error starting up Snowpark Connect server", exc_info=True)
         attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e
+    finally:
+        # Flush the telemetry queue if possible
+        telemetry.shutdown()
 
 
 def start_session(
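
Every handler touched above follows the same request-summary lifecycle: start a summary when the RPC arrives, record any failure in the except branches, and flush the summary in a finally block so it is sent whether the call succeeds or not. The sketch below illustrates that wrapping pattern using only the telemetry methods visible in this diff (initialize_request_summary, report_request_failure, send_request_summary_telemetry); the handle_request helper and do_work callable are illustrative stand-ins, not part of the package.

# Illustrative sketch of the per-request telemetry pattern, not an actual servicer method.
from snowflake.snowpark_connect.utils.telemetry import telemetry

def handle_request(request, do_work):
    telemetry.initialize_request_summary(request)   # start a summary for this RPC
    try:
        return do_work(request)                     # forward the request to the backend
    except Exception as e:
        telemetry.report_request_failure(e)         # attach the failure to the summary
        raise
    finally:
        telemetry.send_request_summary_telemetry()  # always flush, success or failure
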
--- a/snowflake/snowpark_connect/config.py
+++ b/snowflake/snowpark_connect/config.py
@@ -23,6 +23,7 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_support import set_integral_types_conversion
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.context import (
     get_jpype_jclass_lock,
@@ -159,7 +160,16 @@ class GlobalConfig:
         # USE_VECTORIZED_SCANNER will become the default in a future BCR; Snowflake recommends setting it to TRUE for new workloads.
         # This significantly reduces latency for loading Parquet files by downloading only relevant columnar sections into memory.
         "snowpark.connect.parquet.useVectorizedScanner": "true",
+        # USE_LOGICAL_TYPE enables proper handling of Parquet logical types (TIMESTAMP, DATE, DECIMAL).
+        # Without useLogicalType set to "true", Parquet TIMESTAMP (INT64 physical) is incorrectly read as NUMBER(38,0).
+        "snowpark.connect.parquet.useLogicalType": "false",
         "spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue": "false",
+        "spark.sql.parquet.outputTimestampType": "TIMESTAMP_MILLIS",
+        "snowpark.connect.handleIntegralOverflow": "false",
+        "snowpark.connect.scala.version": "2.12",
+        # Control whether to convert decimal - to integral types and vice versa: DecimalType(p,0) <-> ByteType/ShortType/IntegerType/LongType
+        # Values: "client_default" (behavior based on client type), "enabled", "disabled"
+        "snowpark.connect.integralTypesEmulation": "client_default",
     }
 
     boolean_config_list = [
@@ -170,12 +180,14 @@ class GlobalConfig:
         "spark.sql.caseSensitive",
         "snowpark.connect.localRelation.optimizeSmallData",
         "snowpark.connect.parquet.useVectorizedScanner",
+        "snowpark.connect.parquet.useLogicalType",
         "spark.sql.ansi.enabled",
         "spark.sql.legacy.allowHashOnMapType",
         "spark.Catalog.databaseFilterInformationSchema",
         "spark.sql.parser.quotedRegexColumnNames",
         "snowflake.repartition.for.writes",
         "spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue",
+        "snowpark.connect.handleIntegralOverflow",
     ]
 
     int_config_list = [
@@ -192,8 +204,15 @@ class GlobalConfig:
         "spark.app.name": lambda session, name: setattr(
             session, "query_tag", f"Spark-Connect-App-Name={name}"
         ),
+        # TODO SNOW-2896871: Remove with version 1.10.0
         "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
-            session, imports
+            session, imports, "python"
+        ),
+        "snowpark.connect.udf.python.imports": lambda session, imports: parse_imports(
+            session, imports, "python"
+        ),
+        "snowpark.connect.udf.java.imports": lambda session, imports: parse_imports(
+            session, imports, "java"
         ),
     }
 
@@ -359,6 +378,11 @@ CONFIG_ALLOWED_VALUES: dict[str, tuple] = {
         "all",
         "none",
     ),
+    "snowpark.connect.integralTypesEmulation": (
+        "client_default",
+        "enabled",
+        "disabled",
+    ),
 }
 
 # Set some default configuration that are necessary for the driver.
@@ -641,6 +665,27 @@ def set_snowflake_parameters(
             # TODO: SNOW-2367714 Remove this once the fix is automatically enabled in Snowpark
             snowpark.context._enable_fix_2360274 = str_to_bool(value)
             logger.info(f"Updated snowpark session structured types fix: {value}")
+        case "spark.sql.parquet.outputTimestampType":
+            if value == "TIMESTAMP_MICROS":
+                snowpark_session.sql(
+                    "ALTER SESSION SET UNLOAD_PARQUET_TIME_TIMESTAMP_MILLIS = false"
+                ).collect()
+            else:
+                # Default: TIMESTAMP_MILLIS (or any other value)
+                snowpark_session.sql(
+                    "ALTER SESSION SET UNLOAD_PARQUET_TIME_TIMESTAMP_MILLIS = true"
+                ).collect()
+            logger.info(f"Updated parquet timestamp output type to: {value}")
+        case "snowpark.connect.scala.version":
+            # force java udf helper recreation
+            set_java_udf_creator_initialized_state(False)
+        case "snowpark.connect.integralTypesEmulation":
+            # "client_default" - don't change, let set_spark_version handle it
+            # "enabled" / "disabled" - explicitly set
+            if value.lower() == "enabled":
+                set_integral_types_conversion(True)
+            elif value.lower() == "disabled":
+                set_integral_types_conversion(False)
         case _:
             pass
 
@@ -726,15 +771,22 @@ def external_table_location() -> Optional[str]:
     )
 
 
-def parse_imports(session: snowpark.Session, imports: str | None) -> None:
+def parse_imports(
+    session: snowpark.Session, imports: str | None, language: str
+) -> None:
     if not imports:
         return
 
     # UDF needs to be recreated to include new imports
     clear_external_udxf_cache(session)
+    if language == "java":
+
+        set_java_udf_creator_initialized_state(False)
 
     for udf_import in imports.strip("[] ").split(","):
-        session.add_import(udf_import)
+        udf_import = udf_import.strip()
+        if udf_import:
+            session.add_import(udf_import)
 
 
 def get_timestamp_type():
@@ -827,3 +879,20 @@ def check_table_supports_operation(table_identifier: str, operation: str) -> boo
         return table_metadata.get("supports_column_rename", True)
 
     return True
+
+
+def get_scala_version() -> str:
+    return global_config.get("snowpark.connect.scala.version")
+
+
+_java_udf_creator_initialized = False
+
+
+def is_java_udf_creator_initialized() -> bool:
+    global _java_udf_creator_initialized
+    return _java_udf_creator_initialized
+
+
+def set_java_udf_creator_initialized_state(value: bool) -> None:
+    global _java_udf_creator_initialized
+    _java_udf_creator_initialized = value
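
The new keys are read through the same configuration surface as the existing ones, so a client can set them the usual way once a Snowpark Connect session is available. A hedged sketch, assuming a live session object named spark; the specific values are examples only, and the stage path passed to the Java import key is hypothetical.

# Sketch only: setting the configuration keys introduced in this release.
spark.conf.set("snowpark.connect.parquet.useLogicalType", "true")            # boolean config
spark.conf.set("snowpark.connect.handleIntegralOverflow", "true")            # boolean config
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")  # or TIMESTAMP_MILLIS (default)
spark.conf.set("snowpark.connect.scala.version", "2.13")                     # default is "2.12"
# Allowed values per CONFIG_ALLOWED_VALUES: "client_default", "enabled", "disabled"
spark.conf.set("snowpark.connect.integralTypesEmulation", "enabled")
# Language-specific UDF import keys supersede snowpark.connect.udf.imports (hypothetical stage path)
spark.conf.set("snowpark.connect.udf.java.imports", "[@my_stage/udf_deps.jar]")
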
--- /dev/null
+++ b/snowflake/snowpark_connect/expression/error_utils.py
@@ -0,0 +1,28 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark.column import Column
+from snowflake.snowpark.types import DataType, StringType
+
+
+def raise_error_helper(return_type: DataType, error_class=None):
+    error_class_str = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class_str}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )
+
+    return _raise_fn
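
Since a SQL expression cannot raise an exception directly, raise_error_helper builds a column that applies ABS() to a deliberately non-numeric string: the concatenation of a [snowpark-connect-exception:...] marker and the message columns. Evaluating ABS over that string fails at query execution, and the marker plus the optional error-class name presumably let the failure be mapped back to the corresponding PySpark exception elsewhere in the package. A minimal usage sketch, assuming only the API shown in this file; the column name and guard condition are illustrative.

# Sketch: build a column that fails at execution time with a tagged message.
import snowflake.snowpark.functions as snowpark_fn
from pyspark.errors.exceptions.base import ArithmeticException
from snowflake.snowpark.types import LongType
from snowflake.snowpark_connect.expression.error_utils import raise_error_helper

raise_error = raise_error_helper(LongType(), ArithmeticException)

col = snowpark_fn.col("value")
# Evaluates to col while col is non-negative; otherwise the ABS() over the
# tagged, non-numeric string errors out server-side and carries the message.
guarded = snowpark_fn.when(
    col < 0,
    raise_error(snowpark_fn.lit("negative value: "), col),
).otherwise(col.cast(LongType()))
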
--- /dev/null
+++ b/snowflake/snowpark_connect/expression/integral_types_support.py
@@ -0,0 +1,219 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from pyspark.errors.exceptions.base import ArithmeticException
+
+import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark.column import Column
+from snowflake.snowpark.types import (
+    ByteType,
+    DataType,
+    IntegerType,
+    LongType,
+    ShortType,
+    StringType,
+)
+from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.expression.error_utils import raise_error_helper
+
+
+def get_integral_type_bounds(typ: DataType) -> tuple[int, int]:
+    if isinstance(typ, ByteType):
+        return (-128, 127)
+    elif isinstance(typ, ShortType):
+        return (-32768, 32767)
+    elif isinstance(typ, IntegerType):
+        return (-2147483648, 2147483647)
+    elif isinstance(typ, LongType):
+        return (-9223372036854775808, 9223372036854775807)
+    else:
+        raise ValueError(f"Unsupported integral type: {typ}")
+
+
+def apply_integral_overflow(col: Column, to_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    range_size = max_val - min_val + 1
+
+    offset_value = col - snowpark_fn.lit(min_val)
+    wrapped_offset = snowpark_fn.function("MOD")(
+        offset_value, snowpark_fn.lit(range_size)
+    )
+
+    wrapped_offset = snowpark_fn.when(
+        wrapped_offset < 0, wrapped_offset + snowpark_fn.lit(range_size)
+    ).otherwise(wrapped_offset)
+
+    wrapped_result = wrapped_offset + snowpark_fn.lit(min_val)
+
+    return snowpark_fn.when(
+        (col >= snowpark_fn.lit(min_val)) & (col <= snowpark_fn.lit(max_val)),
+        col.cast(to_type),
+    ).otherwise(wrapped_result.cast(to_type))
+
+
+def apply_fractional_to_integral_cast(col: Column, to_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+
+    clamped = (
+        snowpark_fn.when(col > snowpark_fn.lit(max_val), snowpark_fn.lit(max_val))
+        .when(col < snowpark_fn.lit(min_val), snowpark_fn.lit(min_val))
+        .otherwise(col)
+    )
+
+    return clamped.cast(to_type)
+
+
+def apply_integral_overflow_with_ansi_check(
+    col: Column, to_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    if not ansi_enabled:
+        return apply_integral_overflow(col, to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    type_name = to_type.typeName().upper()
+
+    raise_error = raise_error_helper(to_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (col < snowpark_fn.lit(min_val)) | (col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit("[CAST_OVERFLOW] The value "),
+            col.cast(StringType()),
+            snowpark_fn.lit(
+                f" of the type BIGINT cannot be cast to {type_name} due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead."
+            ),
+        ),
+    ).otherwise(col.cast(to_type))
+
+
+def apply_fractional_to_integral_cast_with_ansi_check(
+    col: Column, to_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    if not ansi_enabled:
+        return apply_fractional_to_integral_cast(col, to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    type_name = to_type.typeName().upper()
+
+    raise_error = raise_error_helper(to_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (col < snowpark_fn.lit(min_val)) | (col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit("[CAST_OVERFLOW] The value "),
+            col.cast(StringType()),
+            snowpark_fn.lit(
+                f" of the type DOUBLE cannot be cast to {type_name} "
+                f"due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead."
+            ),
+        ),
+    ).otherwise(col.cast(to_type))
+
+
+def apply_arithmetic_overflow_with_ansi_check(
+    result_col: Column, result_type: DataType, ansi_enabled: bool, operation_name: str
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return result_col.cast(result_type)
+
+    if not ansi_enabled:
+        return apply_integral_overflow(result_col, result_type)
+
+    min_val, max_val = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (result_col < snowpark_fn.lit(min_val))
+        | (result_col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit(
+                f"[ARITHMETIC_OVERFLOW] {operation_name} overflow. "
+                f"Use 'try_{operation_name.lower()}' to tolerate overflow and return NULL instead. "
+                f'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise(result_col.cast(result_type))
+
+
+def apply_unary_overflow(value_col: Column, result_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return (value_col * snowpark_fn.lit(-1)).cast(result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        snowpark_fn.lit(min_val).cast(result_type),
+    ).otherwise((value_col * snowpark_fn.lit(-1)).cast(result_type))
+
+
+def apply_unary_overflow_with_ansi_check(
+    value_col: Column, result_type: DataType, ansi_enabled: bool, operation_name: str
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return (value_col * snowpark_fn.lit(-1)).cast(result_type)
+
+    if not ansi_enabled:
+        return apply_unary_overflow(value_col, result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        raise_error(
+            snowpark_fn.lit(
+                f"[ARITHMETIC_OVERFLOW] {operation_name} overflow. "
+                f'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise((value_col * snowpark_fn.lit(-1)).cast(result_type))
+
+
+def apply_abs_overflow(value_col: Column, result_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return snowpark_fn.abs(value_col).cast(result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        snowpark_fn.lit(min_val).cast(result_type),
+    ).otherwise(snowpark_fn.abs(value_col).cast(result_type))
+
+
+def apply_abs_overflow_with_ansi_check(
+    value_col: Column, result_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return snowpark_fn.abs(value_col).cast(result_type)
+
+    if not ansi_enabled:
+        return apply_abs_overflow(value_col, result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        raise_error(
+            snowpark_fn.lit(
+                "[ARITHMETIC_OVERFLOW] abs overflow. "
+                'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise(snowpark_fn.abs(value_col).cast(result_type))
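
When snowpark.connect.handleIntegralOverflow is enabled and ANSI mode is off, out-of-range values are wrapped with modular arithmetic over the target type's range, mirroring Java/Spark two's-complement overflow instead of raising an error. The plain-Python sketch below shows the same arithmetic that apply_integral_overflow expresses as Snowpark column expressions; the helper name is illustrative.

# Plain-Python sketch of the wrap-around arithmetic built by apply_integral_overflow.
def wrap_integral(value: int, min_val: int, max_val: int) -> int:
    range_size = max_val - min_val + 1          # e.g. 256 for ByteType
    offset = (value - min_val) % range_size     # Python's % is already non-negative,
    return offset + min_val                     # so no extra negative-MOD fix-up is needed

# ByteType bounds from get_integral_type_bounds: (-128, 127)
assert wrap_integral(127, -128, 127) == 127     # in range: unchanged
assert wrap_integral(128, -128, 127) == -128    # overflow wraps like a Java byte
assert wrap_integral(-129, -128, 127) == 127    # underflow wraps the other way
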