snowpark-connect 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of snowpark-connect has been flagged as potentially problematic.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/column_name_handler.py +200 -102
- snowflake/snowpark_connect/column_qualifier.py +47 -0
- snowflake/snowpark_connect/config.py +51 -16
- snowflake/snowpark_connect/dataframe_container.py +3 -2
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +142 -22
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +9 -3
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +5 -1
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/literal.py +7 -1
- snowflake/snowpark_connect/expression/map_cast.py +17 -5
- snowflake/snowpark_connect/expression/map_expression.py +53 -8
- snowflake/snowpark_connect/expression/map_extension.py +37 -11
- snowflake/snowpark_connect/expression/map_sql_expression.py +102 -32
- snowflake/snowpark_connect/expression/map_udf.py +10 -2
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +38 -14
- snowflake/snowpark_connect/expression/map_unresolved_function.py +1476 -292
- snowflake/snowpark_connect/expression/map_unresolved_star.py +14 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +14 -4
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +38 -13
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +6 -1
- snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +92 -59
- snowflake/snowpark_connect/relation/map_extension.py +38 -17
- snowflake/snowpark_connect/relation/map_join.py +26 -12
- snowflake/snowpark_connect/relation/map_local_relation.py +5 -1
- snowflake/snowpark_connect/relation/map_relation.py +33 -7
- snowflake/snowpark_connect/relation/map_row_ops.py +23 -7
- snowflake/snowpark_connect/relation/map_sql.py +124 -25
- snowflake/snowpark_connect/relation/map_stats.py +5 -1
- snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +49 -13
- snowflake/snowpark_connect/relation/read/map_read.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +8 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +13 -3
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +21 -8
- snowflake/snowpark_connect/relation/read/map_read_text.py +5 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +5 -1
- snowflake/snowpark_connect/relation/stage_locator.py +5 -1
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +160 -48
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources_initializer.py +5 -1
- snowflake/snowpark_connect/server.py +73 -21
- snowflake/snowpark_connect/type_mapping.py +90 -20
- snowflake/snowpark_connect/typed_column.py +8 -6
- snowflake/snowpark_connect/utils/context.py +42 -1
- snowflake/snowpark_connect/utils/describe_query_cache.py +3 -0
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/identifiers.py +11 -3
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +11 -3
- snowflake/snowpark_connect/utils/session.py +24 -4
- snowflake/snowpark_connect/utils/telemetry.py +6 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +5 -1
- snowflake/snowpark_connect/utils/udf_cache.py +5 -3
- snowflake/snowpark_connect/utils/udf_helper.py +20 -6
- snowflake/snowpark_connect/utils/udf_utils.py +4 -4
- snowflake/snowpark_connect/utils/udtf_helper.py +5 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +34 -26
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +1 -1
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +7 -3
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +85 -85
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/server.py

@@ -56,7 +56,11 @@ from snowflake.snowpark_connect.analyze_plan.map_tree_string import map_tree_str
 from snowflake.snowpark_connect.config import route_config_proto
 from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
 from snowflake.snowpark_connect.control_server import ControlServicer
-from snowflake.snowpark_connect.error.
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import (
+    attach_custom_error_code,
+    build_grpc_error_response,
+)
 from snowflake.snowpark_connect.execute_plan.map_execution_command import (
     map_execution_command,
 )

@@ -96,7 +100,7 @@ from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_queries_with_tag,
     interrupt_query,
 )
-from snowflake.snowpark_connect.utils.profiling import profile_method
+from snowflake.snowpark_connect.utils.profiling import PROFILING_ENABLED, profile_method
 from snowflake.snowpark_connect.utils.session import (
     configure_snowpark_session,
     get_or_create_snowpark_session,

@@ -154,9 +158,9 @@ def _handle_exception(context, e: Exception):
     logger.error("Error: %s - %s", type(e).__name__, str(e))

     telemetry.report_request_failure(e)
-
     if tcm.TCM_MODE:
-        #
+        # spark decoder will catch the error and return it to GS gracefully
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e

     from grpc_status import rpc_status
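Most of the changes in this release follow the pattern visible above: create an exception, tag it with a code from `ErrorCodes` via `attach_custom_error_code`, then raise it. A minimal sketch of that pattern, assuming only the two imports shown in the diff (the `risky_operation` helper is hypothetical):

```python
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code


def risky_operation():  # hypothetical stand-in for any server-side call
    raise ValueError("something went wrong")


try:
    risky_operation()
except Exception as e:
    # Tag the exception with a Snowpark Connect error code before re-raising,
    # so callers (e.g. the Spark decoder in TCM mode) can report it gracefully.
    attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
    raise
```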
@@ -374,9 +378,13 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     ),
                 )
             case _:
-
+                exception = SnowparkConnectNotImplementedError(
                     f"ANALYZE PLAN NOT IMPLEMENTED:\n{request}"
                 )
+                attach_custom_error_code(
+                    exception, ErrorCodes.UNSUPPORTED_OPERATION
+                )
+                raise exception
         except Exception as e:
             _handle_exception(context, e)
         finally:

@@ -527,9 +535,13 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     ),
                 )
             case _:
-
+                exception = ValueError(
                     f"Unexpected payload type in AddArtifacts: {request.WhichOneof('payload')}"
                 )
+                attach_custom_error_code(
+                    exception, ErrorCodes.UNSUPPORTED_OPERATION
+                )
+                raise exception

         for name, data in cache_data.items():
             _try_handle_local_relation(name, bytes(data))

@@ -635,9 +647,13 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
             case proto_base.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID:
                 interrupted_ids = interrupt_query(request.operation_id)
             case _:
-
+                exception = SnowparkConnectNotImplementedError(
                     f"INTERRUPT NOT IMPLEMENTED:\n{request}"
                 )
+                attach_custom_error_code(
+                    exception, ErrorCodes.UNSUPPORTED_OPERATION
+                )
+                raise exception

         return proto_base.InterruptResponse(
             session_id=request.session_id,

@@ -655,9 +671,11 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         continue. If there is a ResultComplete, the client should use ReleaseExecute with
         """
         logger.info("ReattachExecute")
-
+        exception = SnowparkConnectNotImplementedError(
             "Spark client has detached, please resubmit request. In a future version, the server will be support the reattach."
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception

     def ReleaseExecute(self, request: proto_base.ReleaseExecuteRequest, context):
         """Release an reattachable execution, or parts thereof.

@@ -760,8 +778,11 @@ def _serve(

     ChannelBuilder.MAX_MESSAGE_LENGTH = grpc_max_msg_size

+    # cProfile doesn't work correctly with multiple threads
+    max_workers = 1 if PROFILING_ENABLED else 10
+
     server = grpc.server(
-        futures.ThreadPoolExecutor(max_workers=
+        futures.ThreadPoolExecutor(max_workers=max_workers), options=server_options
     )
     control_servicer = ControlServicer(session)
     proto_base_grpc.add_SparkConnectServiceServicer_to_server(
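A sketch of the worker-count change in `_serve`, assuming `PROFILING_ENABLED` is the module-level flag imported from `snowflake.snowpark_connect.utils.profiling`; the flag value and `server_options` below are stand-ins for what the real server builds:

```python
from concurrent import futures

import grpc

PROFILING_ENABLED = False  # stand-in for the imported flag
server_options = []        # stand-in for the real gRPC server options

# cProfile does not work reliably across threads, so fall back to a single
# worker whenever profiling is enabled.
max_workers = 1 if PROFILING_ENABLED else 10
server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=max_workers), options=server_options
)
```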
@@ -791,6 +812,7 @@ def _serve(
             )
         else:
             logger.error("Error starting up Snowpark Connect server", exc_info=True)
+            attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
             raise e
     finally:
         # flush the telemetry queue if possible

@@ -808,7 +830,9 @@ def _set_remote_url(remote_url: str):
     elif parsed_url.scheme == "unix":
         _server_url = remote_url.split("/;")[0]
     else:
-
+        exception = RuntimeError(f"Invalid Snowpark Connect URL: {remote_url}")
+        attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+        raise exception


 def _set_server_tcp_port(server_port: int):
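A small illustration of the scheme handling in `_set_remote_url` above: `urllib.parse.urlparse` recognizes the `unix` scheme, and the trailing `/;key=value` parameter block is stripped before the socket path is stored (the sample URL below is made up):

```python
from urllib.parse import urlparse

remote_url = "unix:///tmp/snowpark-connect.sock/;use_ssl=false"  # hypothetical URL
parsed_url = urlparse(remote_url)

if parsed_url.scheme == "unix":
    server_url = remote_url.split("/;")[0]  # keep only the socket path portion
else:
    raise RuntimeError(f"Invalid Snowpark Connect URL: {remote_url}")

print(server_url)  # unix:///tmp/snowpark-connect.sock
```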
@@ -822,7 +846,9 @@ def _check_port_is_free(port: int) -> None:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.settimeout(1)
         if s.connect_ex(("127.0.0.1", port)) == 0:
-
+            exception = RuntimeError(f"TCP port {port} is already in use")
+            attach_custom_error_code(exception, ErrorCodes.TCP_PORT_ALREADY_IN_USE)
+            raise exception


 def _set_server_unix_domain_socket(path: str):
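The port check in `_check_port_is_free` boils down to a `connect_ex` probe against localhost; a self-contained sketch of the same idea:

```python
import socket


def port_in_use(port: int) -> bool:
    """Return True if something is already listening on 127.0.0.1:<port>."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1)
        # connect_ex returns 0 when the connection succeeds, i.e. the port is taken.
        return s.connect_ex(("127.0.0.1", port)) == 0


if port_in_use(15002):
    raise RuntimeError("TCP port 15002 is already in use")
```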
@@ -834,14 +860,18 @@ def _set_server_unix_domain_socket(path: str):
 def get_server_url() -> str:
     global _server_url
     if not _server_url:
-
+        exception = RuntimeError("Server URL not set")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
     return _server_url


 def get_client_url() -> str:
     global _client_url
     if not _client_url:
-
+        exception = RuntimeError("Client URL not set")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
     return _client_url



@@ -871,12 +901,14 @@ class UnixDomainSocketChannelBuilder(ChannelBuilder):
         if url is None:
             url = get_client_url()
         if url[:6] != "unix:/" or len(url) < 7:
-
+            exception = PySparkValueError(
                 error_class="INVALID_CONNECT_URL",
                 message_parameters={
                     "detail": "The URL must start with 'unix://'. Please update the URL to follow the correct format, e.g., 'unix://unix_domain_socket_path'.",
                 },
             )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+            raise exception

         # Rewrite the URL to use http as the scheme so that we can leverage
         # Python's built-in parser to parse URL parameters

@@ -919,7 +951,7 @@ class UnixDomainSocketChannelBuilder(ChannelBuilder):
         for p in parts:
             kv = p.split("=")
             if len(kv) != 2:
-
+                exception = PySparkValueError(
                     error_class="INVALID_CONNECT_URL",
                     message_parameters={
                         "detail": f"Parameter '{p}' should be provided as a "

@@ -927,6 +959,10 @@ class UnixDomainSocketChannelBuilder(ChannelBuilder):
                         f"the parameter to follow the correct format, e.g., 'key=value'.",
                     },
                 )
+                attach_custom_error_code(
+                    exception, ErrorCodes.INVALID_SPARK_CONNECT_URL
+                )
+                raise exception
             self.params[kv[0]] = urllib.parse.unquote(kv[1])

         netloc = self.url.netloc.split(":")
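The parameter validation added above enforces the `key=value` shape for each `;`-separated part of the connect URL before the value is URL-decoded; a rough sketch of that parsing loop (the sample parameter string is made up):

```python
import urllib.parse

parts = "user_id=alice;session_id=abc%3D123".split(";")  # hypothetical parameters
params = {}
for p in parts:
    kv = p.split("=")
    if len(kv) != 2:
        raise ValueError(f"Parameter '{p}' should be provided as 'key=value'")
    params[kv[0]] = urllib.parse.unquote(kv[1])

print(params)  # {'user_id': 'alice', 'session_id': 'abc=123'}
```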
@@ -942,7 +978,7 @@ class UnixDomainSocketChannelBuilder(ChannelBuilder):
             self.host = netloc[0]
             self.port = int(netloc[1])
         else:
-
+            exception = PySparkValueError(
                 error_class="INVALID_CONNECT_URL",
                 message_parameters={
                     "detail": f"Target destination '{self.url.netloc}' should match the "

@@ -950,6 +986,8 @@ class UnixDomainSocketChannelBuilder(ChannelBuilder):
                     f"the correct format, e.g., 'hostname:port'.",
                 },
             )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+            raise exception

         # We override this to enable compatibility with Spark 4.0
         host = None

@@ -988,9 +1026,11 @@ def start_jvm():
         if tcm.TCM_MODE:
             # No-op if JVM is already started in TCM mode
             return
-
+        exception = RuntimeError(
             "JVM must not be running when starting the Spark Connect server"
         )
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception

     pyspark_jars = (
         pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes/jars"

@@ -1065,14 +1105,18 @@ def start_session(
     _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size

     if os.environ.get("SPARK_ENV_LOADED"):
-
+        exception = RuntimeError(
             "Snowpark Connect cannot be run inside of a Spark environment"
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_OPERATION)
+        raise exception
     if connection_parameters is not None:
         if snowpark_session is not None:
-
+            exception = ValueError(
                 "Only specify one of snowpark_session and connection_parameters"
             )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_INPUT)
+            raise exception
         snowpark_session = snowpark.Session.builder.configs(
             connection_parameters
         ).create()

@@ -1084,9 +1128,11 @@ def start_session(
         return

     if len(list(filter(None, [remote_url, tcp_port, unix_domain_socket]))) > 1:
-
+        exception = RuntimeError(
             "Can only set at most one of remote_url, tcp_port, and unix_domain_socket"
         )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_INPUT)
+        raise exception

     url_from_env = os.environ.get("SPARK_REMOTE", None)
     if remote_url:

@@ -1124,7 +1170,11 @@ def start_session(
             server_thread.start()
             _server_running.wait()
             if _server_error:
-
+                exception = RuntimeError("Snowpark Connect session failed to start")
+                attach_custom_error_code(
+                    exception, ErrorCodes.STARTUP_CONNECTION_FAILED
+                )
+                raise exception
             return server_thread
         else:
             # Launch in the foreground.

@@ -1132,6 +1182,7 @@ def start_session(
     except Exception as e:
         _reset_server_run_state()
         logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e


@@ -1165,6 +1216,7 @@ def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSessi
     except Exception as e:
         _reset_server_run_state()
         logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e

snowflake/snowpark_connect/type_mapping.py

@@ -29,12 +29,17 @@ from snowflake.snowpark_connect.constants import (
 from snowflake.snowpark_connect.date_time_format_mapping import (
     convert_spark_format_to_snowflake,
 )
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_sql_expression import (
     _INTERVAL_DAYTIME_PATTERN_RE,
     _INTERVAL_YEARMONTH_PATTERN_RE,
 )
-from snowflake.snowpark_connect.utils.context import
+from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
+    get_jpype_jclass_lock,
+)
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,

@@ -61,12 +66,14 @@ SNOWPARK_TYPE_NAME_TO_PYSPARK_TYPE_NAME = {

 @cache
 def _get_struct_type_class():
-
+    with get_jpype_jclass_lock():
+        return jpype.JClass("org.apache.spark.sql.types.StructType")


 @cache
 def get_python_sql_utils_class():
-
+    with get_jpype_jclass_lock():
+        return jpype.JClass("org.apache.spark.sql.api.python.PythonSQLUtils")


 def _parse_ddl_with_spark_scala(ddl_string: str) -> pyspark.sql.types.DataType:

@@ -291,9 +298,11 @@ def snowpark_to_proto_type(
                 )
             }
         case _:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def cast_to_match_snowpark_type(

@@ -333,7 +342,9 @@ def cast_to_match_snowpark_type(
             with suppress(TypeError):
                 date = datetime.strptime(content, format)
                 return date
-
+            exception = ValueError(f"Date casting error for {str(content)}")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+            raise exception
         case snowpark.types.ShortType:
             return int(content)
         case snowpark.types.StringType:

@@ -363,9 +374,11 @@ def cast_to_match_snowpark_type(
         case snowpark.types.DayTimeIntervalType:
             return str(content)
         case _:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def snowpark_to_iceberg_type(data_type: snowpark.types.DataType) -> str:

@@ -398,9 +411,11 @@ def snowpark_to_iceberg_type(data_type: snowpark.types.DataType) -> str:
         case snowpark.types.TimestampType:
             return "timestamp"
         case _:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type for iceber: {data_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def proto_to_snowpark_type(

@@ -487,9 +502,11 @@ def map_snowpark_types_to_pyarrow_types(
                     )
                 )
             else:
-
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark ArrayType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.BinaryType:
             return pa.binary()
         case snowpark.types.BooleanType:

@@ -530,9 +547,11 @@ def map_snowpark_types_to_pyarrow_types(
                     ),
                 )
             else:
-
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark MapType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.NullType:
             return pa.string()
         case snowpark.types.ShortType:

@@ -557,15 +576,20 @@ def map_snowpark_types_to_pyarrow_types(
                     ]
                 )
             else:
-
+                exception = AnalysisException(
                     f"Unsupported arrow type {pa_type} for snowpark StructType."
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+                raise exception
         case snowpark.types.TimestampType:
-
-
+            # Check if pa_type has unit attribute (it should be a timestamp type)
+            unit = pa_type.unit if hasattr(pa_type, "unit") else "us"
+            tz = pa_type.tz if hasattr(pa_type, "tz") else None
+
+            # Spark truncates nanosecond precision to microseconds
             if unit == "ns":
-                # Spark truncates nanosecond precision to microseconds
                 unit = "us"
+
             return pa.timestamp(unit, tz=tz)
         case snowpark.types.VariantType:
             return pa.string()
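The timestamp branch above now reads the unit and timezone defensively and truncates nanoseconds to microseconds; a minimal PyArrow sketch of that behaviour:

```python
import pyarrow as pa

pa_type = pa.timestamp("ns", tz="UTC")

# Fall back to microseconds / no timezone if the attributes are missing.
unit = pa_type.unit if hasattr(pa_type, "unit") else "us"
tz = pa_type.tz if hasattr(pa_type, "tz") else None

if unit == "ns":
    unit = "us"  # Spark truncates nanosecond precision to microseconds

assert pa.timestamp(unit, tz=tz) == pa.timestamp("us", tz="UTC")
```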
@@ -576,9 +600,11 @@ def map_snowpark_types_to_pyarrow_types(
             # Return string type so formatted intervals are preserved in display
             return pa.string()
         case _:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {snowpark_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_pyarrow_to_snowpark_types(pa_type: pa.DataType) -> snowpark.types.DataType:

@@ -647,10 +673,15 @@ def map_pyarrow_to_snowpark_types(pa_type: pa.DataType) -> snowpark.types.DataTy
         return snowpark.types.TimestampType()
     elif pa.types.is_null(pa_type):
         return snowpark.types.NullType()
+    elif pa.types.is_duration(pa_type):
+        # Map PyArrow duration[us] to DayTimeIntervalType
+        return snowpark.types.DayTimeIntervalType()
     else:
-
+        exception = SnowparkConnectNotImplementedError(
             f"Unsupported PyArrow data type: {pa_type}"
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception


 def map_pyspark_types_to_snowpark_types(
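The new `pa.types.is_duration` branch maps PyArrow duration values onto Snowpark's day-time interval type; a simplified stand-in for `map_pyarrow_to_snowpark_types` showing just that case, assuming the bundled snowflake-snowpark-python exposes `DayTimeIntervalType` as used in the diff:

```python
import pyarrow as pa

from snowflake import snowpark


def to_snowpark_type(pa_type: pa.DataType) -> snowpark.types.DataType:
    # Simplified stand-in: only the duration branch added in this release.
    if pa.types.is_duration(pa_type):
        return snowpark.types.DayTimeIntervalType()
    raise NotImplementedError(f"Unsupported PyArrow data type: {pa_type}")


print(to_snowpark_type(pa.duration("us")))  # DayTimeIntervalType
```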
@@ -736,9 +767,11 @@ def map_pyspark_types_to_snowpark_types(
         return snowpark.types.DayTimeIntervalType(
             type_to_map.startField, type_to_map.endField
         )
-
+    exception = SnowparkConnectNotImplementedError(
         f"Unsupported spark data type: {type_to_map}"
     )
+    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+    raise exception


 def map_snowpark_to_pyspark_types(

@@ -811,7 +844,11 @@ def map_snowpark_to_pyspark_types(
         return pyspark.sql.types.DayTimeIntervalType(
             type_to_map.start_field, type_to_map.end_field
         )
-
+    exception = SnowparkConnectNotImplementedError(
+        f"Unsupported data type: {type_to_map}"
+    )
+    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+    raise exception


 def map_simple_types(simple_type: str) -> snowpark.types.DataType:

@@ -861,14 +898,43 @@ def map_simple_types(simple_type: str) -> snowpark.types.DataType:
             return snowpark.types.YearMonthIntervalType()
         case type_name if _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
             return snowpark.types.DayTimeIntervalType()
+        # Year-Month interval cases
+        case "interval year":
+            return snowpark.types.YearMonthIntervalType(0)
+        case "interval month":
+            return snowpark.types.YearMonthIntervalType(1)
+        case "interval year to month":
+            return snowpark.types.YearMonthIntervalType(0, 1)
+        case "interval day":
+            return snowpark.types.DayTimeIntervalType(0)
+        case "interval hour":
+            return snowpark.types.DayTimeIntervalType(1)
+        case "interval minute":
+            return snowpark.types.DayTimeIntervalType(2)
+        case "interval second":
+            return snowpark.types.DayTimeIntervalType(3)
+        case "interval day to hour":
+            return snowpark.types.DayTimeIntervalType(0, 1)
+        case "interval day to minute":
+            return snowpark.types.DayTimeIntervalType(0, 2)
+        case "interval day to second":
+            return snowpark.types.DayTimeIntervalType(0, 3)
+        case "interval hour to minute":
+            return snowpark.types.DayTimeIntervalType(1, 2)
+        case "interval hour to second":
+            return snowpark.types.DayTimeIntervalType(1, 3)
+        case "interval minute to second":
+            return snowpark.types.DayTimeIntervalType(2, 3)
         case _:
             if simple_type.startswith("decimal"):
                 precision = int(simple_type.split("(")[1].split(",")[0])
                 scale = int(simple_type.split(",")[1].split(")")[0])
                 return snowpark.types.DecimalType(precision, scale)
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported simple type: {simple_type}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_json_schema_to_snowpark(
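The new `interval ...` cases in `map_simple_types` encode the start and end fields numerically, following Spark's convention (YEAR=0, MONTH=1 for year-month intervals; DAY=0, HOUR=1, MINUTE=2, SECOND=3 for day-time intervals). A trimmed-down sketch covering two of the new cases:

```python
from snowflake import snowpark


def map_interval(simple_type: str) -> snowpark.types.DataType:
    # Simplified stand-in for the interval handling added to map_simple_types.
    match simple_type:
        case "interval year to month":
            return snowpark.types.YearMonthIntervalType(0, 1)  # YEAR..MONTH
        case "interval day to second":
            return snowpark.types.DayTimeIntervalType(0, 3)    # DAY..SECOND
        case _:
            raise NotImplementedError(f"Unsupported simple type: {simple_type}")
```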
@@ -1009,9 +1075,11 @@ def map_spark_timestamp_format_expression(
             lit_value, _ = get_literal_field_and_name(arguments.literal)
             return convert_spark_format_to_snowflake(lit_value, timestamp_input_type)
         case other:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported expression type {other} in timestamp format argument"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception


 def map_spark_number_format_expression(

@@ -1030,9 +1098,11 @@ def map_spark_number_format_expression(
         case "literal":
             lit_value, _ = get_literal_field_and_name(arguments.literal)
         case other:
-
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported expression type {other} in number format argument"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception

     return _map_spark_to_snowflake_number_format(lit_value)


snowflake/snowpark_connect/typed_column.py

@@ -8,6 +8,7 @@ from functools import cached_property
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark.column import Column
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier

 _EMPTY_COLUMN = Column("")


@@ -44,11 +45,11 @@ class TypedColumn:
     def alias(self, alias_name: str):
         return TypedColumn(self.col.alias(alias_name), self._type_resolver)

-    def set_qualifiers(self, qualifiers:
+    def set_qualifiers(self, qualifiers: set[ColumnQualifier]) -> None:
         self.qualifiers = qualifiers

-    def get_qualifiers(self) ->
-        return getattr(self, "qualifiers",
+    def get_qualifiers(self) -> set[ColumnQualifier]:
+        return getattr(self, "qualifiers", {ColumnQualifier.no_qualifier()})

     def set_catalog_database_info(self, catalog_database_info: dict[str, str]) -> None:
         self._catalog_database_info = catalog_database_info

@@ -63,12 +64,13 @@ class TypedColumn:
     def get_database(self) -> str | None:
         return self._catalog_database_info.get("database")

-    def set_multi_col_qualifiers(self, qualifiers: list[
+    def set_multi_col_qualifiers(self, qualifiers: list[set[ColumnQualifier]]) -> None:
         self.multi_col_qualifiers = qualifiers

-    def get_multi_col_qualifiers(self, num_columns) -> list[
+    def get_multi_col_qualifiers(self, num_columns) -> list[set[ColumnQualifier]]:
         if not hasattr(self, "multi_col_qualifiers"):
-
+
+            return [{ColumnQualifier.no_qualifier()} for i in range(num_columns)]
         assert (
             len(self.multi_col_qualifiers) == num_columns
         ), f"Expected {num_columns} multi-column qualifiers, got {len(self.multi_col_qualifiers)}"
snowflake/snowpark_connect/utils/context.py

@@ -2,10 +2,12 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #

+import os
 import re
+import threading
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import Mapping, Optional
+from typing import Iterator, Mapping, Optional

 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto


@@ -25,6 +27,7 @@ _is_evaluating_sql = ContextVar[bool]("_is_evaluating_sql", default=False)
 _is_evaluating_join_condition = ContextVar(
     "_is_evaluating_join_condition", default=("default", False, [], [])
 )
+_is_processing_order_by = ContextVar[bool]("_is_processing_order_by", default=False)

 _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0

@@ -56,6 +59,23 @@ _is_in_pivot = ContextVar[bool]("_is_in_pivot", default=False)
 _is_in_udtf_context = ContextVar[bool]("_is_in_udtf_context", default=False)
 _accessing_temp_object = ContextVar[bool]("_accessing_temp_object", default=False)

+# Thread-safe lock for JPype JClass creation to prevent access violations
+_jpype_jclass_lock = threading.Lock()
+
+
+@contextmanager
+def get_jpype_jclass_lock() -> Iterator[None]:
+    """
+    Context manager that acquires the JPype JClass lock on Windows platforms.
+    On non-Windows (os.name != 'nt'), it yields without acquiring the lock.
+    """
+    if os.name == "nt":
+        with _jpype_jclass_lock:
+            yield
+    else:
+        yield
+
+
 # Lateral Column Alias helpers
 # We keep a thread-local mapping from alias name -> TypedColumn that is
 # populated incrementally while the projection list is being processed.
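Usage sketch for the `get_jpype_jclass_lock` helper added above: on Windows it serializes `jpype.JClass` lookups behind a module-level lock, elsewhere it is a no-op. The class looked up here is the same one the type_mapping diff wraps, and a running JVM is assumed:

```python
import jpype

from snowflake.snowpark_connect.utils.context import get_jpype_jclass_lock

with get_jpype_jclass_lock():
    # Concurrent JClass creation has caused access violations on Windows,
    # so the lookup is done while holding the lock.
    struct_type_class = jpype.JClass("org.apache.spark.sql.types.StructType")
```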
@@ -207,6 +227,27 @@ def push_evaluating_sql_scope():
         _is_evaluating_sql.set(prev)


+def get_is_processing_order_by() -> bool:
+    """
+    Gets the value of _is_processing_order_by for the current context, defaults to False.
+    """
+    return _is_processing_order_by.get()
+
+
+@contextmanager
+def push_processing_order_by_scope():
+    """
+    Context manager that sets a flag indicating if ORDER BY expressions are being evaluated.
+    This enables optimizations like reusing already-computed UDF columns.
+    """
+    prev = _is_processing_order_by.get()
+    try:
+        _is_processing_order_by.set(True)
+        yield
+    finally:
+        _is_processing_order_by.set(prev)
+
+
 def get_is_evaluating_join_condition() -> tuple[str, bool, list, list]:
     """
     Gets the value of _is_evaluating_join_condition for the current context, defaults to False.
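Usage sketch for the new ORDER BY scope helpers: the context manager flips the flag for the duration of the block and restores the previous value afterwards:

```python
from snowflake.snowpark_connect.utils.context import (
    get_is_processing_order_by,
    push_processing_order_by_scope,
)

assert get_is_processing_order_by() is False

with push_processing_order_by_scope():
    # Inside the scope, ORDER BY-specific optimizations (e.g. reusing
    # already-computed UDF columns) can key off this flag.
    assert get_is_processing_order_by() is True

assert get_is_processing_order_by() is False
```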
snowflake/snowpark_connect/utils/describe_query_cache.py

@@ -12,6 +12,8 @@ from typing import Any
 from snowflake import snowpark
 from snowflake.connector.cursor import ResultMetadataV2
 from snowflake.snowpark._internal.server_connection import ServerConnection
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry

@@ -148,6 +150,7 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
             telemetry.report_query(result, **kwargs)
         except Exception as e:
             telemetry.report_query(e, **kwargs)
+            attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
             raise e
         return result
