snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/client/server.py +680 -0
@@ -0,0 +1,680 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Remote Spark Connect Server for Snowpark Connect.
+
+Lightweight servicer that forwards Spark Connect requests to Snowflake backend
+via REST API using SparkConnectResource SDK.
+"""
+
+import threading
+import uuid
+from concurrent import futures
+from typing import Dict, Iterator, Optional
+
+import grpc
+import pyarrow as pa
+from google.rpc import code_pb2
+from grpc_status import rpc_status
+from pyspark.conf import SparkConf
+from pyspark.sql.connect.proto import base_pb2, base_pb2_grpc, types_pb2
+from pyspark.sql.connect.session import SparkSession
+from snowflake.core.spark_connect._spark_connect import SparkConnectResource
+
+from snowflake import snowpark
+from snowflake.snowpark import Session
+from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.client.exceptions import (
+    GrpcErrorStatusException,
+    UnexpectedResponseException,
+)
+from snowflake.snowpark_connect.client.query_results import (
+    fetch_query_result_as_arrow_batches,
+    fetch_query_result_as_protobuf,
+)
+from snowflake.snowpark_connect.client.utils.session import (
+    get_or_create_snowpark_session,
+)
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.server_common import (  # noqa: F401 - re-exported for public API
+    _disable_protobuf_recursion_limit,
+    _get_default_grpc_options,
+    _reset_server_run_state,
+    _setup_spark_environment,
+    _stop_server,
+    configure_server_url,
+    get_client_url,
+    get_server_error,
+    get_server_running,
+    get_server_url,
+    get_session,
+    set_grpc_max_message_size,
+    set_server_error,
+    setup_signal_handlers,
+    validate_startup_parameters,
+)
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
+from spark.connect import envelope_pb2
+
+
+def _log_and_return_error(
+    err_mesg: str,
+    error: Exception,
+    status_code: grpc.StatusCode,
+    context: grpc.ServicerContext,
+) -> None:
+    """Log error and set gRPC context."""
+    context.set_details(str(error))
+    context.set_code(status_code)
+    logger.error(f"{err_mesg} status code: {status_code}", exc_info=True)
+    return None
+
+
+def _validate_response_type(
+    resp_envelope: envelope_pb2.ResponseEnvelope, expected_field: str
+) -> None:
+    """Validate that response envelope has expected type."""
+    field_name = resp_envelope.WhichOneof("response_type")
+    if field_name != expected_field:
+        raise UnexpectedResponseException(
+            f"Expected response type {expected_field}, got {field_name}"
+        )
+
+
+def _build_result_complete_response(
+    request: base_pb2.ExecutePlanRequest,
+) -> base_pb2.ExecutePlanResponse:
+    return base_pb2.ExecutePlanResponse(
+        session_id=request.session_id,
+        operation_id=request.operation_id or "0",
+        result_complete=base_pb2.ExecutePlanResponse.ResultComplete(),
+    )
+
+
+def _build_exec_plan_resp_stream_from_df_query_result(
+    request: base_pb2.ExecutePlanRequest,
+    session: Session,
+    query_result: envelope_pb2.DataframeQueryResult,
+) -> Iterator[base_pb2.ExecutePlanResponse]:
+    query_id = query_result.result_job_uuid
+    arrow_schema = pa.ipc.read_schema(pa.BufferReader(query_result.arrow_schema))
+    spark_schema = types_pb2.DataType()
+    spark_schema.ParseFromString(query_result.spark_schema)
+
+    for row_count, arrow_batch_bytes in fetch_query_result_as_arrow_batches(
+        session, query_id, arrow_schema
+    ):
+        yield base_pb2.ExecutePlanResponse(
+            session_id=request.session_id,
+            operation_id=request.operation_id or "0",
+            arrow_batch=base_pb2.ExecutePlanResponse.ArrowBatch(
+                row_count=row_count,
+                data=arrow_batch_bytes,
+            ),
+            schema=spark_schema,
+        )
+
+    yield _build_result_complete_response(request)
+
+
+def _build_exec_plan_resp_stream_from_resp_envelope(
+    request: base_pb2.ExecutePlanRequest,
+    session: Session,
+    resp_envelope: envelope_pb2.ResponseEnvelope,
+) -> Iterator[base_pb2.ExecutePlanResponse]:
+    """Build execution plan response stream from response envelope."""
+    resp_type = resp_envelope.WhichOneof("response_type")
+
+    if resp_type == "dataframe_query_result":
+        query_result = resp_envelope.dataframe_query_result
+        yield from _build_exec_plan_resp_stream_from_df_query_result(
+            request, session, query_result
+        )
+    elif resp_type == "execute_plan_response":
+        yield resp_envelope.execute_plan_response
+        yield _build_result_complete_response(request)
+    elif resp_type == "status":
+        raise GrpcErrorStatusException(resp_envelope.status)
+    else:
+        logger.warning(f"Unexpected response type: {resp_type}")
+
+
+class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
+    # Configs frequently read by PySpark client but not supported in SAS.
+    # We return the request's default value directly without calling the backend.
+    #
+    # Why this matters:
+    # - These configs are read via get_config_with_defaults() on every show()/toPandas()
+    # - In SAS, show() internally calls toPandas() (see pyspark/sql/connect/dataframe.py _show_string)
+    # - Without this optimization, every show() would make a backend config request
+    #
+    # Source: OSS Spark python/pyspark/sql/connect/client/core.py to_pandas() method
+    _UNSUPPORTED_CONFIGS: frozenset[str] = frozenset(
+        {
+            # Read on every show() and toPandas() call - controls Arrow memory optimization
+            "spark.sql.execution.arrow.pyspark.selfDestruct.enabled",
+            # Read on toPandas() when DataFrame contains StructType fields
+            "spark.sql.execution.pandas.structHandlingMode",
+        }
+    )
+
+    def __init__(self, snowpark_session: Session) -> None:
+        self.snowpark_session = snowpark_session
+        self._config_cache: SynchronizedDict[str, str] = SynchronizedDict()
+
+    def _get_spark_resource(self) -> SparkConnectResource:
+        return SparkConnectResource(self.snowpark_session)
+
+    def _parse_response_envelope(
+        self, response_bytes: bytes | bytearray, expected_resp_type: str = None
+    ) -> envelope_pb2.ResponseEnvelope:
+        """Parse and validate response envelope from GS backend."""
+
+        resp_envelope = envelope_pb2.ResponseEnvelope()
+        if isinstance(response_bytes, bytearray):
+            response_bytes = bytes(response_bytes)
+        resp_envelope.ParseFromString(response_bytes)
+
+        resp_type = resp_envelope.WhichOneof("response_type")
+        if resp_type == "status":
+            raise GrpcErrorStatusException(resp_envelope.status)
+
+        if not resp_envelope.query_id and not resp_type == "dataframe_query_result":
+            _validate_response_type(resp_envelope, expected_resp_type)
+
+        return resp_envelope
+
+    def ExecutePlan(
+        self, request: base_pb2.ExecutePlanRequest, context: grpc.ServicerContext
+    ) -> Iterator[base_pb2.ExecutePlanResponse]:
+        """Execute a Spark plan by forwarding to GS backend."""
+        logger.debug("Received Execute Plan request")
+        query_id = None
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.execute_plan(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "execute_plan_response"
+            )
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                job_res_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, resp_envelope.query_id
+                )
+                yield from _build_exec_plan_resp_stream_from_resp_envelope(
+                    request, self.snowpark_session, job_res_envelope
+                )
+            else:
+                yield from _build_exec_plan_resp_stream_from_resp_envelope(
+                    request, self.snowpark_session, resp_envelope
+                )
+
+        except GrpcErrorStatusException as e:
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            logger.error(f"Error in ExecutePlan, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in ExecutePlan call", e, grpc.StatusCode.INTERNAL, context
+            )
+
+    def _call_backend_config(
+        self, request: base_pb2.ConfigRequest
+    ) -> base_pb2.ConfigResponse:
+        """Forward config request to GS and return response."""
+        spark_resource = self._get_spark_resource()
+        response_bytes = spark_resource.config(request.SerializeToString())
+        resp_envelope = self._parse_response_envelope(response_bytes, "config_response")
+
+        query_id = resp_envelope.query_id
+        if query_id:
+            resp_envelope = fetch_query_result_as_protobuf(
+                self.snowpark_session, query_id
+            )
+            assert resp_envelope.WhichOneof("response_type") == "config_response"
+
+        return resp_envelope.config_response
+
+    def _handle_get_cached_config_request(
+        self,
+        request: base_pb2.ConfigRequest,
+        items: list[tuple[str, Optional[str]]],
+        op_name: str,
+        update_cache_on_miss: bool = True,
+    ) -> base_pb2.ConfigResponse:
+        """
+        Handle config requests with caching and unsupported config checks.
+
+        Args:
+            request: The original ConfigRequest.
+            items: List of (key, default_value) tuples. default_value is None for get/get_option.
+            op_name: Name of the operation for logging (e.g., "get", "get_with_default").
+            update_cache_on_miss: Whether to update the cache with values returned from backend.
+        """
+        keys = [k for k, _ in items]
+
+        # 1. Unsupported Configs Check
+        # If all keys are unsupported, return defaults (if any) or empty response without calling backend
+        if all(key in self._UNSUPPORTED_CONFIGS for key in keys):
+            response = base_pb2.ConfigResponse(session_id=request.session_id)
+            for key, default_val in items:
+                resp_pair = response.pairs.add()
+                resp_pair.key = key
+                if default_val is not None:
+                    resp_pair.value = default_val
+            logger.debug(f"Config {op_name} returning defaults for unsupported: {keys}")
+            return response
+
+        # 2. Cache Check
+        # Check if all keys are in cache
+        cached_values = {key: self._config_cache.get(key) for key in keys}
+        if cached_values and all(value is not None for value in cached_values.values()):
+            response = base_pb2.ConfigResponse(session_id=request.session_id)
+            for key in keys:
+                resp_pair = response.pairs.add()
+                resp_pair.key = key
+                resp_pair.value = cached_values[key]
+            logger.debug(f"Config {op_name} served from cache: {keys}")
+            return response
+
+        # 3. Cache Miss - Call Backend
+        config_response = self._call_backend_config(request)
+
+        if update_cache_on_miss:
+            for pair in config_response.pairs:
+                if pair.HasField("value"):
+                    self._config_cache[pair.key] = pair.value
+            logger.debug(f"Config {op_name} cached from backend: {keys}")
+        else:
+            logger.debug(f"Config {op_name} from backend (not cached): {keys}")
+
+        return config_response
+
+    def Config(
+        self, request: base_pb2.ConfigRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ConfigResponse:
+        logger.debug("Received Config request")
+
+        try:
+            op = request.operation
+            op_type = op.WhichOneof("op_type")
+
+            match op_type:
+                case "get_with_default":
+                    pairs = op.get_with_default.pairs
+                    items = [
+                        (p.key, p.value if p.HasField("value") else None) for p in pairs
+                    ]
+                    return self._handle_get_cached_config_request(
+                        request, items, "get_with_default", update_cache_on_miss=False
+                    )
+
+                case "get":
+                    keys = op.get.keys
+                    items = [(k, None) for k in keys]
+                    return self._handle_get_cached_config_request(request, items, "get")
+
+                case "set":
+                    config_response = self._call_backend_config(request)
+
+                    for pair in op.set.pairs:
+                        if pair.HasField("value"):
+                            self._config_cache[pair.key] = pair.value
+                    logger.debug(
+                        f"Config set updated cache: {[p.key for p in op.set.pairs]}"
+                    )
+                    return config_response
+
+                case "unset":
+                    config_response = self._call_backend_config(request)
+
+                    for key in op.unset.keys:
+                        self._config_cache.remove(key)
+                    logger.debug(f"Config unset updated cache: {list(op.unset.keys)}")
+                    return config_response
+
+                case "get_option":
+                    keys = op.get_option.keys
+                    items = [(k, None) for k in keys]
+                    return self._handle_get_cached_config_request(
+                        request, items, "get_option"
+                    )
+
+                case "get_all":
+                    # Always call backend since this is a prefix-based search and we
+                    # can't know if all matching keys are in cache. Cache the results.
+                    config_response = self._call_backend_config(request)
+
+                    # Cache all returned values
+                    for pair in config_response.pairs:
+                        if pair.HasField("value"):
+                            self._config_cache[pair.key] = pair.value
+                    prefix = (
+                        op.get_all.prefix if op.get_all.HasField("prefix") else "all"
+                    )
+                    logger.debug(
+                        f"Config get_all cached {len(config_response.pairs)} items (prefix={prefix})"
+                    )
+                    return config_response
+
+                case _:
+                    # Forward other operations to backend (no caching)
+                    logger.debug(
+                        f"Forwarding unknown config request of type {op_type} to the backend"
+                    )
+                    return self._call_backend_config(request)
+
+        except GrpcErrorStatusException as e:
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            logger.error("Error in Config", exc_info=True)
+            return _log_and_return_error(
+                "Error in Config call", e, grpc.StatusCode.INTERNAL, context
+            )
+
+    def AnalyzePlan(
+        self, request: base_pb2.AnalyzePlanRequest, context: grpc.ServicerContext
+    ) -> base_pb2.AnalyzePlanResponse:
+        logger.debug("Received Analyze Plan request")
+        query_id = None
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.analyze_plan(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "analyze_plan_response"
+            )
+
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                resp_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, query_id
+                )
+                assert (
+                    resp_envelope.WhichOneof("response_type") == "analyze_plan_response"
+                )
+
+            return resp_envelope.analyze_plan_response
+
+        except GrpcErrorStatusException as e:
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            logger.error(f"Error in AnalyzePlan, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in AnalyzePlan call", e, grpc.StatusCode.INTERNAL, context
+            )
+
+    def AddArtifacts(
+        self,
+        request_iterator: Iterator[base_pb2.AddArtifactsRequest],
+        context: grpc.ServicerContext,
+    ) -> base_pb2.AddArtifactsResponse:
+        logger.debug("Received AddArtifacts request")
+        add_artifacts_response = None
+
+        spark_resource = self._get_spark_resource()
+
+        for request in request_iterator:
+            query_id = None
+            try:
+                response_bytes = spark_resource.add_artifacts(
+                    request.SerializeToString()
+                )
+                resp_envelope = self._parse_response_envelope(
+                    response_bytes, "add_artifacts_response"
+                )
+
+                query_id = resp_envelope.query_id
+
+                if query_id:
+                    resp_envelope = fetch_query_result_as_protobuf(
+                        self.snowpark_session, query_id
+                    )
+                    assert (
+                        resp_envelope.WhichOneof("response_type")
+                        == "add_artifacts_response"
+                    )
+
+                add_artifacts_response = resp_envelope.add_artifacts_response
+
+            except GrpcErrorStatusException as e:
+                context.abort_with_status(rpc_status.to_status(e.status))
+            except Exception as e:
+                logger.error(
+                    f"Error in AddArtifacts, query id {query_id}", exc_info=True
+                )
+                return _log_and_return_error(
+                    "Error in AddArtifacts call", e, grpc.StatusCode.INTERNAL, context
+                )
+
+        if add_artifacts_response is None:
+            raise ValueError("AddArtifacts received empty request_iterator")
+
+        return add_artifacts_response
+
+    def ArtifactStatus(
+        self, request: base_pb2.ArtifactStatusesRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ArtifactStatusesResponse:
+        """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
+        logger.debug("Received ArtifactStatus request")
+        query_id = None
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.artifact_status(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "artifact_status_response"
+            )
+
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                resp_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, query_id
+                )
+                assert (
+                    resp_envelope.WhichOneof("response_type")
+                    == "artifact_status_response"
+                )
+
+            return resp_envelope.artifact_status_response
+        except GrpcErrorStatusException as e:
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            logger.error(f"Error in ArtifactStatus, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in ArtifactStatus call", e, grpc.StatusCode.INTERNAL, context
+            )
+
+    def Interrupt(
+        self, request: base_pb2.InterruptRequest, context: grpc.ServicerContext
+    ) -> base_pb2.InterruptResponse:
+        """Interrupt running executions."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method Interrupt not implemented!")
+        raise NotImplementedError("Method Interrupt not implemented!")
+
+    def ReleaseExecute(
+        self, request: base_pb2.ReleaseExecuteRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ReleaseExecuteResponse:
+        """Release an execution."""
+        logger.debug("Received Release Execute request")
+        try:
+            return base_pb2.ReleaseExecuteResponse(
+                session_id=request.session_id,
+                operation_id=request.operation_id or str(uuid.uuid4()),
+            )
+        except Exception as e:
+            logger.error("Error in ReleaseExecute", exc_info=True)
+            return _log_and_return_error(
+                "Error in ReleaseExecute call", e, grpc.StatusCode.INTERNAL, context
+            )
+
+    def ReattachExecute(
+        self, request: base_pb2.ReattachExecuteRequest, context: grpc.ServicerContext
+    ) -> Iterator[base_pb2.ExecutePlanResponse]:
+        """Reattach to an existing reattachable execution.
+        The ExecutePlan must have been started with ReattachOptions.reattachable=true.
+        If the ExecutePlanResponse stream ends without a ResultComplete message, there is more to
+        continue. If there is a ResultComplete, the client should use ReleaseExecute with
+        """
+        from google.rpc import status_pb2
+
+        status = status_pb2.Status(
+            code=code_pb2.UNIMPLEMENTED,
+            message="Method ReattachExecute not implemented! INVALID_HANDLE.OPERATION_NOT_FOUND",
+        )
+        context.abort_with_status(rpc_status.to_status(status))
+
+
+def _serve(
+    stop_event: Optional[threading.Event] = None,
+    session: Optional[snowpark.Session] = None,
+) -> None:
+    server_running = get_server_running()
+    try:
+        if session is None:
+            session = get_or_create_snowpark_session()
+
+        server_options = _get_default_grpc_options()
+        max_workers = get_int_from_env("SPARK_CONNECT_CLIENT_GRPC_MAX_WORKERS", 10)
+
+        server = grpc.server(
+            futures.ThreadPoolExecutor(max_workers=max_workers),
+            options=server_options,
+        )
+
+        base_pb2_grpc.add_SparkConnectServiceServicer_to_server(
+            SnowflakeConnectClientServicer(session),
+            server,
+        )
+        server_url = get_server_url()
+        server.add_insecure_port(server_url)
+        logger.info(f"Starting Snowpark Connect server on {server_url}...")
+        server.start()
+        server_running.set()
+        logger.info("Snowpark Connect server started!")
+
+        if stop_event is not None:
+            # start a background thread to listen for stop event and terminate the server
+            threading.Thread(
+                target=_stop_server, args=(stop_event, server), daemon=True
+            ).start()
+
+        server.wait_for_termination()
+    except Exception as e:
+        set_server_error(True)
+        server_running.set()  # unblock any client sessions
+        if "Invalid connection_name 'spark-connect', known ones are " in str(e):
+            logger.error(
+                "Ensure 'spark-connect' connection config has been set correctly in connections.toml."
+            )
+        else:
+            logger.error("Error starting up Snowpark Connect server", exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
+        raise e
+
+
+def start_session(
+    is_daemon: bool = True,
+    remote_url: Optional[str] = None,
+    tcp_port: Optional[int] = None,
+    unix_domain_socket: Optional[str] = None,
+    stop_event: threading.Event = None,
+    snowpark_session: Optional[snowpark.Session] = None,
+    connection_parameters: Optional[Dict[str, str]] = None,
+    max_grpc_message_size: int = None,
+    _add_signal_handler: bool = False,
+) -> threading.Thread | None:
+    """
+    Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.
+
+    Parameters:
+        is_daemon (bool): Should run the server as daemon or not. use True to automatically shut the Spark connect
+                          server down when the main program (or test) finishes. use False to start the server in a
+                          stand-alone, long-running mode.
+        remote_url (Optional[str]): sc:// URL on which to start the Spark Connect server. This option is incompatible
+                          with the tcp_port and unix_domain_socket parameters.
+        tcp_port (Optional[int]): TCP port on which to start the Spark Connect server. This option is incompatible with
+                          the remote_url and unix_domain_socket parameters.
+        unix_domain_socket (Optional[str]): Path to the unix domain socket on which to start the Spark Connect server.
+                          This option is incompatible with the remote_url and tcp_port parameters.
+        stop_event (Optional[threading.Event]): Stop the SAS server when stop_event.set() is called.
+                          Only works when is_daemon=True.
+        snowpark_session: A Snowpark session to use for this connection; currently the only applicable use of this is to
+                          pass in the session created by the stored proc environment.
+        connection_parameters: A dictionary of connection parameters to use to create the Snowpark session. If this is
+                          provided, the `snowpark_session` parameter must be None.
+    """
+
+    try:
+        # Set max grpc message size if provided
+        if max_grpc_message_size is not None:
+            set_grpc_max_message_size(max_grpc_message_size)
+
+        # Validate startup parameters
+        snowpark_session = validate_startup_parameters(
+            snowpark_session, connection_parameters
+        )
+
+        server_running = get_server_running()
+        if server_running.is_set():
+            url = get_client_url()
+            logger.warning(f"Snowpark Connect session is already running at {url}")
+            return
+
+        # Configure server URL
+        configure_server_url(remote_url, tcp_port, unix_domain_socket)
+
+        _disable_protobuf_recursion_limit()
+
+        if _add_signal_handler:
+            setup_signal_handlers(stop_event)
+
+        if is_daemon:
+            arguments = (stop_event, snowpark_session)
+            server_thread = threading.Thread(target=_serve, args=arguments, daemon=True)
+            server_thread.start()
+            server_running.wait()
+            if get_server_error():
+                exception = RuntimeError("Snowpark Connect session failed to start")
+                attach_custom_error_code(
+                    exception, ErrorCodes.STARTUP_CONNECTION_FAILED
+                )
+                raise exception
+
+            return server_thread
+        else:
+            # Launch in the foreground with stop_event
+            _serve(stop_event=stop_event, session=snowpark_session)
+    except Exception as e:
+        _reset_server_run_state()
+        logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
+        raise e
+
+
+def init_spark_session(conf: SparkConf = None) -> SparkSession:
+    """
+    Initialize and return a Spark session.
+
+    Parameters:
+        conf (SparkConf): Optional Spark configuration.
+
+    Returns:
+        A new SparkSession connected to the Snowpark Connect thin Client server.
+    """
+    _setup_spark_environment(False)
+    from snowflake.snowpark_connect.client.utils.session import (
+        _get_current_snowpark_session,
+    )
+
+    snowpark_session = _get_current_snowpark_session()
+    start_session(snowpark_session=snowpark_session)
+    return get_session(conf=conf)
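The new thin-client module above exposes `start_session()` and `init_spark_session()` as its public entry points. For orientation only, a minimal usage sketch follows; the connection setup, port number, and query are illustrative assumptions and are not part of this diff.

```python
# Minimal sketch (not from the diff): assumes snowpark-connect 1.6.0 is installed
# and a Snowflake connection is already configured, e.g. via a connections.toml
# profile picked up by the Snowpark session helpers.
from snowflake.snowpark_connect.client.server import init_spark_session, start_session

# Option A: one call that reuses the current Snowpark session, starts the local
# gRPC servicer if needed, and returns a PySpark SparkSession bound to it.
spark = init_spark_session()

# Option B (alternative): start the server explicitly first. Per the docstring,
# remote_url, tcp_port, and unix_domain_socket are mutually exclusive; the port
# shown here is an arbitrary example.
# start_session(is_daemon=True, tcp_port=15002)

# Requests issued through the SparkSession are forwarded to the Snowflake backend.
spark.sql("SELECT 1 AS x").show()
```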