snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/client/server.py +37 -0
- snowflake/snowpark_connect/config.py +72 -3
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/map_cast.py +108 -17
- snowflake/snowpark_connect/expression/map_udf.py +1 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
- snowflake/snowpark_connect/resources_initializer.py +90 -29
- snowflake/snowpark_connect/server.py +6 -41
- snowflake/snowpark_connect/server_common/__init__.py +4 -1
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/utils/context.py +8 -0
- snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
- snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
- snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
- snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
- snowflake/snowpark_connect/utils/telemetry.py +33 -22
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/client/server.py

```diff
@@ -58,6 +58,7 @@ from snowflake.snowpark_connect.server_common import ( # noqa: F401 - re-export
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
+from snowflake.snowpark_connect.utils.telemetry import telemetry
 from spark.connect import envelope_pb2
 
 
@@ -194,6 +195,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         """Execute a Spark plan by forwarding to GS backend."""
         logger.debug("Received Execute Plan request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -216,12 +218,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             )
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in ExecutePlan, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in ExecutePlan call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def _call_backend_config(
         self, request: base_pb2.ConfigRequest
@@ -299,6 +305,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         self, request: base_pb2.ConfigRequest, context: grpc.ServicerContext
     ) -> base_pb2.ConfigResponse:
         logger.debug("Received Config request")
+        telemetry.initialize_request_summary(request)
 
         try:
             op = request.operation
@@ -370,18 +377,23 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             return self._call_backend_config(request)
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error("Error in Config", exc_info=True)
             return _log_and_return_error(
                 "Error in Config call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def AnalyzePlan(
         self, request: base_pb2.AnalyzePlanRequest, context: grpc.ServicerContext
     ) -> base_pb2.AnalyzePlanResponse:
         logger.debug("Received Analyze Plan request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -403,12 +415,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
             return resp_envelope.analyze_plan_response
 
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in AnalyzePlan, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in AnalyzePlan call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def AddArtifacts(
         self,
@@ -422,6 +438,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
 
         for request in request_iterator:
             query_id = None
+            telemetry.initialize_request_summary(request)
             try:
                 response_bytes = spark_resource.add_artifacts(
                     request.SerializeToString()
@@ -444,14 +461,18 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
                 add_artifacts_response = resp_envelope.add_artifacts_response
 
             except GrpcErrorStatusException as e:
+                telemetry.report_request_failure(e)
                 context.abort_with_status(rpc_status.to_status(e.status))
             except Exception as e:
+                telemetry.report_request_failure(e)
                 logger.error(
                     f"Error in AddArtifacts, query id {query_id}", exc_info=True
                 )
                 return _log_and_return_error(
                     "Error in AddArtifacts call", e, grpc.StatusCode.INTERNAL, context
                 )
+            finally:
+                telemetry.send_request_summary_telemetry()
 
         if add_artifacts_response is None:
             raise ValueError("AddArtifacts received empty request_iterator")
@@ -464,6 +485,7 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
         """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
         logger.debug("Received ArtifactStatus request")
         query_id = None
+        telemetry.initialize_request_summary(request)
 
         try:
             spark_resource = self._get_spark_resource()
@@ -485,12 +507,16 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
 
             return resp_envelope.artifact_status_response
         except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
             context.abort_with_status(rpc_status.to_status(e.status))
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error(f"Error in ArtifactStatus, query id {query_id}", exc_info=True)
             return _log_and_return_error(
                 "Error in ArtifactStatus call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def Interrupt(
         self, request: base_pb2.InterruptRequest, context: grpc.ServicerContext
@@ -505,16 +531,20 @@ class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
     ) -> base_pb2.ReleaseExecuteResponse:
         """Release an execution."""
         logger.debug("Received Release Execute request")
+        telemetry.initialize_request_summary(request)
         try:
             return base_pb2.ReleaseExecuteResponse(
                 session_id=request.session_id,
                 operation_id=request.operation_id or str(uuid.uuid4()),
             )
         except Exception as e:
+            telemetry.report_request_failure(e)
             logger.error("Error in ReleaseExecute", exc_info=True)
             return _log_and_return_error(
                 "Error in ReleaseExecute call", e, grpc.StatusCode.INTERNAL, context
             )
+        finally:
+            telemetry.send_request_summary_telemetry()
 
     def ReattachExecute(
         self, request: base_pb2.ReattachExecuteRequest, context: grpc.ServicerContext
@@ -542,6 +572,9 @@ def _serve(
     if session is None:
         session = get_or_create_snowpark_session()
 
+    # Initialize telemetry with session and thin client source identifier
+    telemetry.initialize(session, source="SparkConnectLightWeightClient")
+
     server_options = _get_default_grpc_options()
     max_workers = get_int_from_env("SPARK_CONNECT_CLIENT_GRPC_MAX_WORKERS", 10)
 
@@ -560,6 +593,7 @@ def _serve(
         server.start()
         server_running.set()
         logger.info("Snowpark Connect server started!")
+        telemetry.send_server_started_telemetry()
 
         if stop_event is not None:
             # start a background thread to listen for stop event and terminate the server
@@ -579,6 +613,9 @@ def _serve(
         logger.error("Error starting up Snowpark Connect server", exc_info=True)
         attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e
+    finally:
+        # Flush the telemetry queue if possible
+        telemetry.shutdown()
 
 
 def start_session(
```
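The telemetry calls added above follow one request-lifecycle pattern in every RPC handler: initialize a summary when the request arrives, record any failure, and flush the summary in a `finally` block. A minimal sketch of that pattern, using a hypothetical stand-in class; the real `telemetry` object lives in `snowflake/snowpark_connect/utils/telemetry.py`, whose internals are not shown in this diff.

```python
# Illustrative sketch only: RequestSummaryTelemetry is a made-up stand-in that
# mirrors the method names called by the handlers above.
class RequestSummaryTelemetry:
    def __init__(self) -> None:
        self._summary = None

    def initialize_request_summary(self, request) -> None:
        # Start a fresh summary keyed by the incoming request type.
        self._summary = {"request_type": type(request).__name__, "error": None}

    def report_request_failure(self, error: Exception) -> None:
        # Record the failure; the summary is still flushed in `finally`.
        if self._summary is not None:
            self._summary["error"] = repr(error)

    def send_request_summary_telemetry(self) -> None:
        # Flush exactly once per request, success or failure.
        summary, self._summary = self._summary, None
        if summary is not None:
            print("telemetry:", summary)  # stand-in for the real emitter


telemetry = RequestSummaryTelemetry()


def execute_plan(request):
    telemetry.initialize_request_summary(request)
    try:
        return {"ok": True}  # stand-in for forwarding the plan to the backend
    except Exception as e:
        telemetry.report_request_failure(e)
        raise
    finally:
        telemetry.send_request_summary_telemetry()


execute_plan(request=object())
```

Flushing in `finally` guarantees a summary is emitted whether the backend call succeeds, raises `GrpcErrorStatusException`, or fails unexpectedly.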
snowflake/snowpark_connect/config.py

```diff
@@ -23,6 +23,7 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_support import set_integral_types_conversion
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.context import (
     get_jpype_jclass_lock,
@@ -159,7 +160,16 @@ class GlobalConfig:
         # USE_VECTORIZED_SCANNER will become the default in a future BCR; Snowflake recommends setting it to TRUE for new workloads.
         # This significantly reduces latency for loading Parquet files by downloading only relevant columnar sections into memory.
         "snowpark.connect.parquet.useVectorizedScanner": "true",
+        # USE_LOGICAL_TYPE enables proper handling of Parquet logical types (TIMESTAMP, DATE, DECIMAL).
+        # Without useLogicalType set to "true", Parquet TIMESTAMP (INT64 physical) is incorrectly read as NUMBER(38,0).
+        "snowpark.connect.parquet.useLogicalType": "false",
         "spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue": "false",
+        "spark.sql.parquet.outputTimestampType": "TIMESTAMP_MILLIS",
+        "snowpark.connect.handleIntegralOverflow": "false",
+        "snowpark.connect.scala.version": "2.12",
+        # Control whether to convert decimal to integral types and vice versa: DecimalType(p,0) <-> ByteType/ShortType/IntegerType/LongType
+        # Values: "client_default" (behavior based on client type), "enabled", "disabled"
+        "snowpark.connect.integralTypesEmulation": "client_default",
     }
 
     boolean_config_list = [
@@ -170,12 +180,14 @@ class GlobalConfig:
         "spark.sql.caseSensitive",
         "snowpark.connect.localRelation.optimizeSmallData",
         "snowpark.connect.parquet.useVectorizedScanner",
+        "snowpark.connect.parquet.useLogicalType",
         "spark.sql.ansi.enabled",
         "spark.sql.legacy.allowHashOnMapType",
         "spark.Catalog.databaseFilterInformationSchema",
         "spark.sql.parser.quotedRegexColumnNames",
         "snowflake.repartition.for.writes",
         "spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue",
+        "snowpark.connect.handleIntegralOverflow",
     ]
 
     int_config_list = [
@@ -192,8 +204,15 @@ class GlobalConfig:
         "spark.app.name": lambda session, name: setattr(
             session, "query_tag", f"Spark-Connect-App-Name={name}"
         ),
+        # TODO SNOW-2896871: Remove with version 1.10.0
         "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
-            session, imports
+            session, imports, "python"
+        ),
+        "snowpark.connect.udf.python.imports": lambda session, imports: parse_imports(
+            session, imports, "python"
+        ),
+        "snowpark.connect.udf.java.imports": lambda session, imports: parse_imports(
+            session, imports, "java"
         ),
     }
 
@@ -359,6 +378,11 @@ CONFIG_ALLOWED_VALUES: dict[str, tuple] = {
         "all",
         "none",
     ),
+    "snowpark.connect.integralTypesEmulation": (
+        "client_default",
+        "enabled",
+        "disabled",
+    ),
 }
 
 # Set some default configuration that are necessary for the driver.
```
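The defaults above register several user-visible settings: `snowpark.connect.parquet.useLogicalType`, `spark.sql.parquet.outputTimestampType`, `snowpark.connect.handleIntegralOverflow`, `snowpark.connect.scala.version`, and `snowpark.connect.integralTypesEmulation`. A hedged usage sketch follows, assuming these keys are set through the regular Spark Connect configuration API handled by the Config RPC shown earlier; the connection URL is illustrative, and the accepted values come from the defaults and `CONFIG_ALLOWED_VALUES` in this diff.

```python
from pyspark.sql import SparkSession

# Illustrative endpoint; in practice this points at the Snowpark Connect server.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Read Parquet logical types (TIMESTAMP, DATE, DECIMAL) instead of raw
# physical types; boolean config, default "false" per the diff.
spark.conf.set("snowpark.connect.parquet.useLogicalType", "true")

# Emulate Spark's integral overflow behavior on casts and arithmetic;
# boolean config, default "false".
spark.conf.set("snowpark.connect.handleIntegralOverflow", "true")

# DecimalType(p,0) <-> ByteType/ShortType/IntegerType/LongType emulation:
# one of "client_default", "enabled", "disabled".
spark.conf.set("snowpark.connect.integralTypesEmulation", "enabled")

# Parquet timestamp output type; the set_snowflake_parameters handling in the
# next hunk maps this onto UNLOAD_PARQUET_TIME_TIMESTAMP_MILLIS.
spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
```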
snowflake/snowpark_connect/config.py (continued)

```diff
@@ -641,6 +665,27 @@ def set_snowflake_parameters(
             # TODO: SNOW-2367714 Remove this once the fix is automatically enabled in Snowpark
             snowpark.context._enable_fix_2360274 = str_to_bool(value)
             logger.info(f"Updated snowpark session structured types fix: {value}")
+        case "spark.sql.parquet.outputTimestampType":
+            if value == "TIMESTAMP_MICROS":
+                snowpark_session.sql(
+                    "ALTER SESSION SET UNLOAD_PARQUET_TIME_TIMESTAMP_MILLIS = false"
+                ).collect()
+            else:
+                # Default: TIMESTAMP_MILLIS (or any other value)
+                snowpark_session.sql(
+                    "ALTER SESSION SET UNLOAD_PARQUET_TIME_TIMESTAMP_MILLIS = true"
+                ).collect()
+            logger.info(f"Updated parquet timestamp output type to: {value}")
+        case "snowpark.connect.scala.version":
+            # force java udf helper recreation
+            set_java_udf_creator_initialized_state(False)
+        case "snowpark.connect.integralTypesEmulation":
+            # "client_default" - don't change, let set_spark_version handle it
+            # "enabled" / "disabled" - explicitly set
+            if value.lower() == "enabled":
+                set_integral_types_conversion(True)
+            elif value.lower() == "disabled":
+                set_integral_types_conversion(False)
         case _:
             pass
 
@@ -726,15 +771,22 @@ def external_table_location() -> Optional[str]:
     )
 
 
-def parse_imports(
+def parse_imports(
+    session: snowpark.Session, imports: str | None, language: str
+) -> None:
     if not imports:
         return
 
     # UDF needs to be recreated to include new imports
     clear_external_udxf_cache(session)
+    if language == "java":
+
+        set_java_udf_creator_initialized_state(False)
 
     for udf_import in imports.strip("[] ").split(","):
-
+        udf_import = udf_import.strip()
+        if udf_import:
+            session.add_import(udf_import)
 
 
 def get_timestamp_type():
@@ -827,3 +879,20 @@ def check_table_supports_operation(table_identifier: str, operation: str) -> bool
         return table_metadata.get("supports_column_rename", True)
 
     return True
+
+
+def get_scala_version() -> str:
+    return global_config.get("snowpark.connect.scala.version")
+
+
+_java_udf_creator_initialized = False
+
+
+def is_java_udf_creator_initialized() -> bool:
+    global _java_udf_creator_initialized
+    return _java_udf_creator_initialized
+
+
+def set_java_udf_creator_initialized_state(value: bool) -> None:
+    global _java_udf_creator_initialized
+    _java_udf_creator_initialized = value
```
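The reworked `parse_imports` now takes a `language` argument, resets the Java UDF helper state when Java imports change, and skips empty entries when splitting the bracketed import list. A small, self-contained restatement of just the string handling; the helper name and stage paths below are illustrative, and the `session.add_import` call is omitted.

```python
# Mirrors imports.strip("[] ").split(",") plus the new per-entry strip/skip
# logic from parse_imports above; split_udf_imports is a hypothetical name.
def split_udf_imports(imports: str | None) -> list[str]:
    if not imports:
        return []
    parts = (item.strip() for item in imports.strip("[] ").split(","))
    return [item for item in parts if item]


# Whitespace and trailing empty entries are now tolerated.
assert split_udf_imports("[@stage/a.py, @stage/b.zip, ]") == [
    "@stage/a.py",
    "@stage/b.zip",
]
assert split_udf_imports(None) == []
```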
snowflake/snowpark_connect/expression/error_utils.py (new file)

```diff
@@ -0,0 +1,28 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark.column import Column
+from snowflake.snowpark.types import DataType, StringType
+
+
+def raise_error_helper(return_type: DataType, error_class=None):
+    error_class_str = (
+        f":{error_class.__name__}"
+        if error_class and hasattr(error_class, "__name__")
+        else ""
+    )
+
+    def _raise_fn(*msgs: Column) -> Column:
+        return snowpark_fn.cast(
+            snowpark_fn.abs(
+                snowpark_fn.concat(
+                    snowpark_fn.lit(f"[snowpark-connect-exception{error_class_str}]"),
+                    *(msg.try_cast(StringType()) for msg in msgs),
+                )
+            ).cast(StringType()),
+            return_type,
+        )
+
+    return _raise_fn
```
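`raise_error_helper` builds a deferred error: it returns a function that produces a Column which, when evaluated, applies `ABS()` to a concatenated string tagged `[snowpark-connect-exception:<ErrorClass>]`, so the query fails server-side with a message that can be mapped back to a PySpark exception. A hedged usage sketch follows; the column name and bounds are illustrative, and the call shape mirrors how `integral_types_support.py` (below) invokes the helper.

```python
import snowflake.snowpark.functions as snowpark_fn
from pyspark.errors.exceptions.base import ArithmeticException
from snowflake.snowpark.types import ByteType, StringType
from snowflake.snowpark_connect.expression.error_utils import raise_error_helper

# Build an error-raising expression that is typed as ByteType so it can sit in
# either branch of a WHEN without changing the overall column type.
raise_error = raise_error_helper(ByteType(), ArithmeticException)

checked_cast = snowpark_fn.when(
    (snowpark_fn.col("x") < snowpark_fn.lit(-128))
    | (snowpark_fn.col("x") > snowpark_fn.lit(127)),
    raise_error(
        snowpark_fn.lit("[CAST_OVERFLOW] The value "),
        snowpark_fn.col("x").cast(StringType()),
        snowpark_fn.lit(" does not fit in TINYINT."),
    ),
).otherwise(snowpark_fn.col("x").cast(ByteType()))
```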
snowflake/snowpark_connect/expression/integral_types_support.py (new file)

```diff
@@ -0,0 +1,219 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from pyspark.errors.exceptions.base import ArithmeticException
+
+import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark.column import Column
+from snowflake.snowpark.types import (
+    ByteType,
+    DataType,
+    IntegerType,
+    LongType,
+    ShortType,
+    StringType,
+)
+from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.expression.error_utils import raise_error_helper
+
+
+def get_integral_type_bounds(typ: DataType) -> tuple[int, int]:
+    if isinstance(typ, ByteType):
+        return (-128, 127)
+    elif isinstance(typ, ShortType):
+        return (-32768, 32767)
+    elif isinstance(typ, IntegerType):
+        return (-2147483648, 2147483647)
+    elif isinstance(typ, LongType):
+        return (-9223372036854775808, 9223372036854775807)
+    else:
+        raise ValueError(f"Unsupported integral type: {typ}")
+
+
+def apply_integral_overflow(col: Column, to_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    range_size = max_val - min_val + 1
+
+    offset_value = col - snowpark_fn.lit(min_val)
+    wrapped_offset = snowpark_fn.function("MOD")(
+        offset_value, snowpark_fn.lit(range_size)
+    )
+
+    wrapped_offset = snowpark_fn.when(
+        wrapped_offset < 0, wrapped_offset + snowpark_fn.lit(range_size)
+    ).otherwise(wrapped_offset)
+
+    wrapped_result = wrapped_offset + snowpark_fn.lit(min_val)
+
+    return snowpark_fn.when(
+        (col >= snowpark_fn.lit(min_val)) & (col <= snowpark_fn.lit(max_val)),
+        col.cast(to_type),
+    ).otherwise(wrapped_result.cast(to_type))
+
+
+def apply_fractional_to_integral_cast(col: Column, to_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+
+    clamped = (
+        snowpark_fn.when(col > snowpark_fn.lit(max_val), snowpark_fn.lit(max_val))
+        .when(col < snowpark_fn.lit(min_val), snowpark_fn.lit(min_val))
+        .otherwise(col)
+    )
+
+    return clamped.cast(to_type)
+
+
+def apply_integral_overflow_with_ansi_check(
+    col: Column, to_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    if not ansi_enabled:
+        return apply_integral_overflow(col, to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    type_name = to_type.typeName().upper()
+
+    raise_error = raise_error_helper(to_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (col < snowpark_fn.lit(min_val)) | (col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit("[CAST_OVERFLOW] The value "),
+            col.cast(StringType()),
+            snowpark_fn.lit(
+                f" of the type BIGINT cannot be cast to {type_name} due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead."
+            ),
+        ),
+    ).otherwise(col.cast(to_type))
+
+
+def apply_fractional_to_integral_cast_with_ansi_check(
+    col: Column, to_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return col.cast(to_type)
+
+    if not ansi_enabled:
+        return apply_fractional_to_integral_cast(col, to_type)
+
+    min_val, max_val = get_integral_type_bounds(to_type)
+    type_name = to_type.typeName().upper()
+
+    raise_error = raise_error_helper(to_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (col < snowpark_fn.lit(min_val)) | (col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit("[CAST_OVERFLOW] The value "),
+            col.cast(StringType()),
+            snowpark_fn.lit(
+                f" of the type DOUBLE cannot be cast to {type_name} "
+                f"due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead."
+            ),
+        ),
+    ).otherwise(col.cast(to_type))
+
+
+def apply_arithmetic_overflow_with_ansi_check(
+    result_col: Column, result_type: DataType, ansi_enabled: bool, operation_name: str
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return result_col.cast(result_type)
+
+    if not ansi_enabled:
+        return apply_integral_overflow(result_col, result_type)
+
+    min_val, max_val = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        (result_col < snowpark_fn.lit(min_val))
+        | (result_col > snowpark_fn.lit(max_val)),
+        raise_error(
+            snowpark_fn.lit(
+                f"[ARITHMETIC_OVERFLOW] {operation_name} overflow. "
+                f"Use 'try_{operation_name.lower()}' to tolerate overflow and return NULL instead. "
+                f'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise(result_col.cast(result_type))
+
+
+def apply_unary_overflow(value_col: Column, result_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return (value_col * snowpark_fn.lit(-1)).cast(result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        snowpark_fn.lit(min_val).cast(result_type),
+    ).otherwise((value_col * snowpark_fn.lit(-1)).cast(result_type))
+
+
+def apply_unary_overflow_with_ansi_check(
+    value_col: Column, result_type: DataType, ansi_enabled: bool, operation_name: str
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return (value_col * snowpark_fn.lit(-1)).cast(result_type)
+
+    if not ansi_enabled:
+        return apply_unary_overflow(value_col, result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        raise_error(
+            snowpark_fn.lit(
+                f"[ARITHMETIC_OVERFLOW] {operation_name} overflow. "
+                f'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise((value_col * snowpark_fn.lit(-1)).cast(result_type))
+
+
+def apply_abs_overflow(value_col: Column, result_type: DataType) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return snowpark_fn.abs(value_col).cast(result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        snowpark_fn.lit(min_val).cast(result_type),
+    ).otherwise(snowpark_fn.abs(value_col).cast(result_type))
+
+
+def apply_abs_overflow_with_ansi_check(
+    value_col: Column, result_type: DataType, ansi_enabled: bool
+) -> Column:
+    if not global_config.snowpark_connect_handleIntegralOverflow:
+        return snowpark_fn.abs(value_col).cast(result_type)
+
+    if not ansi_enabled:
+        return apply_abs_overflow(value_col, result_type)
+
+    min_val, _ = get_integral_type_bounds(result_type)
+
+    raise_error = raise_error_helper(result_type, ArithmeticException)
+
+    return snowpark_fn.when(
+        value_col == snowpark_fn.lit(min_val),
+        raise_error(
+            snowpark_fn.lit(
+                "[ARITHMETIC_OVERFLOW] abs overflow. "
+                'If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+            ),
+        ),
+    ).otherwise(snowpark_fn.abs(value_col).cast(result_type))
```