snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/server_common/__init__.py
@@ -0,0 +1,500 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Common server utilities shared between SAS server and client server.
+
+This module contains shared constants, global state management, URL handling,
+gRPC configuration, and session management code that is used by both
+the main SAS server (server.py) and the thin client's server (client/server.py).
+"""
+
+import atexit
+import os
+import signal
+import socket
+import tempfile
+import threading
+import urllib.parse
+from typing import Any, Dict, List, Optional, Tuple
+
+import grpc
+import pyspark
+from pyspark.conf import SparkConf
+from pyspark.errors import PySparkValueError
+from pyspark.sql.connect.client.core import ChannelBuilder
+from pyspark.sql.connect.session import SparkSession
+
+from packaging import version
+from snowflake import snowpark
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
+
+DEFAULT_PORT = 15002
+
+# https://github.com/apache/spark/blob/v3.5.3/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala#L21
+_SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = 128 * 1024 * 1024
+# TODO: Verify if we want to configure it via env variables.
+_SPARK_CONNECT_GRPC_MAX_METADATA_SIZE = 64 * 1024  # 64kb
+
+# Thread-local storage for client telemetry context
+_client_telemetry_context = threading.local()
+
+_server_running: threading.Event = threading.Event()
+_server_error: bool = False
+_server_url: Optional[str] = None
+_client_url: Optional[str] = None
+
+
+def get_server_running() -> threading.Event:
+    """Get the server running event."""
+    return _server_running
+
+
+def get_server_error() -> bool:
+    """Get the server error flag."""
+    return _server_error
+
+
+def set_server_error(error: bool) -> None:
+    """Set the server error flag."""
+    global _server_error
+    _server_error = error
+
+
+def _reset_server_run_state() -> None:
+    """
+    Reset server global state to the initial blank slate state.
+    Called after the startup error is caught and handled/logged.
+    """
+    global _server_running, _server_error, _server_url, _client_url
+    _server_running.clear()
+    _server_error = False
+    _server_url = None
+    _client_url = None
+
+
+def _stop_server(stop_event: threading.Event, server: grpc.Server) -> None:
+    """Wait for stop event and then stop the server."""
+    stop_event.wait()
+    server.stop(0)
+    _reset_server_run_state()
+    logger.info("server stop sent")
+
+
+def _get_default_grpc_options() -> List[Tuple[str, Any]]:
+    """Get default gRPC server options."""
+    grpc_max_msg_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    )
+    grpc_max_metadata_size = get_int_from_env(
+        "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
+        _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
+    )
+    server_options = [
+        (
+            "grpc.max_send_message_length",
+            grpc_max_msg_size,
+        ),
+        (
+            "grpc.max_receive_message_length",
+            grpc_max_msg_size,
+        ),
+        (
+            "grpc.max_metadata_size",
+            grpc_max_metadata_size,
+        ),
+        (
+            "grpc.absolute_max_metadata_size",
+            grpc_max_metadata_size * 2,
+        ),
+    ]
+
+    # try to adjust max message size for clients in the same process
+    from pyspark.sql.connect.client import ChannelBuilder
+
+    ChannelBuilder.MAX_MESSAGE_LENGTH = grpc_max_msg_size
+
+    return server_options
+
+
+def get_grpc_max_message_size() -> int:
+    """Get the current gRPC max message size."""
+    return _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
+
+
+def set_grpc_max_message_size(size: int) -> None:
+    """Set the gRPC max message size."""
+    global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
+    _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = size
+
+
+def get_server_url() -> str:
+    """Get the server URL."""
+    global _server_url
+    if not _server_url:
+        exception = RuntimeError("Server URL not set")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
+    return _server_url
+
+
+def get_client_url() -> str:
+    """Get the client URL."""
+    global _client_url
+    if not _client_url:
+        exception = RuntimeError("Client URL not set")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
+    return _client_url
+
+
+def _check_port_is_free(port: int) -> None:
+    """Check if a TCP port is available."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(1)
+        if s.connect_ex(("127.0.0.1", port)) == 0:
+            exception = RuntimeError(f"TCP port {port} is already in use")
+            attach_custom_error_code(exception, ErrorCodes.TCP_PORT_ALREADY_IN_USE)
+            raise exception
+
+
+def _set_remote_url(remote_url: str):
+    """Set server and client URLs from a remote URL string."""
+    global _server_url, _client_url
+    _client_url = remote_url
+    parsed_url = urllib.parse.urlparse(remote_url)
+    if parsed_url.scheme == "sc":
+        _server_url = parsed_url.netloc
+        server_port = parsed_url.port or DEFAULT_PORT
+        _check_port_is_free(server_port)
+    elif parsed_url.scheme == "unix":
+        _server_url = remote_url.split("/;")[0]
+    else:
+        exception = RuntimeError(f"Invalid Snowpark Connect URL: {remote_url}")
+        attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+        raise exception
+
+
+def _set_server_tcp_port(server_port: int):
+    """Set server and client URLs from a TCP port."""
+    global _server_url, _client_url
+    _check_port_is_free(server_port)
+    _server_url = f"[::]:{server_port}"
+    _client_url = f"sc://127.0.0.1:{server_port}"
+
+
+def _set_server_unix_domain_socket(path: str):
+    """Set server and client URLs from a Unix domain socket path."""
+    global _server_url, _client_url
+    _server_url = f"unix:{path}"
+    _client_url = f"unix:{path}"
+
+
+def _make_unix_domain_socket() -> str:
+    """Create a unique Unix domain socket path."""
+    parent_dir = tempfile.mkdtemp()
+    server_path = os.path.join(parent_dir, "snowflake_sas_grpc.sock")
+    atexit.register(_cleanup_unix_domain_socket, server_path)
+    return server_path
+
+
+def _cleanup_unix_domain_socket(server_path: str) -> None:
+    """Clean up a Unix domain socket and its parent directory."""
+    parent_dir = os.path.dirname(server_path)
+    if os.path.exists(server_path):
+        os.remove(server_path)
+    if os.path.exists(parent_dir):
+        os.rmdir(parent_dir)
+
+
+class UnixDomainSocketChannelBuilder(ChannelBuilder):
+    """
+    Spark Connect gRPC channel builder for Unix domain sockets.
+    """
+
+    def __init__(
+        self, url: str = None, channelOptions: Optional[List[Tuple[str, Any]]] = None
+    ) -> None:
+        if url is None:
+            url = get_client_url()
+        if url[:6] != "unix:/" or len(url) < 7:
+            exception = PySparkValueError(
+                error_class="INVALID_CONNECT_URL",
+                message_parameters={
+                    "detail": "The URL must start with 'unix://'. Please update the URL to follow the correct format, e.g., 'unix://unix_domain_socket_path'.",
+                },
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+            raise exception
+
+        # Rewrite the URL to use http as the scheme so that we can leverage
+        # Python's built-in parser to parse URL parameters
+        fake_url = "http://" + url[6:]
+        self.url = urllib.parse.urlparse(fake_url)
+        self.params: Dict[str, str] = {}
+        self._extract_attributes()
+
+        # Now parse the real unix domain socket URL
+        self.url = urllib.parse.urlparse(url)
+
+        GRPC_DEFAULT_OPTIONS = _get_default_grpc_options()
+
+        if channelOptions is None:
+            self._channel_options = GRPC_DEFAULT_OPTIONS
+        else:
+            for option in channelOptions:
+                if (
+                    option[0] == "grpc.max_send_message_length"
+                    or option[0] == "grpc.max_receive_message_length"
+                ):
+                    # try to adjust max message size for clients in the same process
+                    from pyspark.sql.connect.client import ChannelBuilder
+
+                    ChannelBuilder.MAX_MESSAGE_LENGTH = max(
+                        ChannelBuilder.MAX_MESSAGE_LENGTH, option[1]
+                    )
+            self._channel_options = GRPC_DEFAULT_OPTIONS + channelOptions
+        # For Spark 4.0 support, but also backwards compatible.
+        self._params = self.params
+
+    def _extract_attributes(self) -> None:
+        """Extract attributes from parameters.
+
+        This method was copied from
+        https://github.com/apache/spark/blob/branch-3.5/python/pyspark/sql/connect/client/core.py
+
+        This is required for Spark 4.0 support, since it is dropped in favor of moving
+        the extraction logic into the constructor.
+        """
+        if len(self.url.params) > 0:
+            parts = self.url.params.split(";")
+            for p in parts:
+                kv = p.split("=")
+                if len(kv) != 2:
+                    exception = PySparkValueError(
+                        error_class="INVALID_CONNECT_URL",
+                        message_parameters={
+                            "detail": f"Parameter '{p}' should be provided as a "
+                            f"key-value pair separated by an equal sign (=). Please update "
+                            f"the parameter to follow the correct format, e.g., 'key=value'.",
+                        },
+                    )
+                    attach_custom_error_code(
+                        exception, ErrorCodes.INVALID_SPARK_CONNECT_URL
+                    )
+                    raise exception
+                self.params[kv[0]] = urllib.parse.unquote(kv[1])
+
+        netloc = self.url.netloc.split(":")
+        if len(netloc) == 1:
+            self.host = netloc[0]
+            if version.parse(pyspark.__version__) >= version.parse("4.0.0"):
+                from pyspark.sql.connect.client.core import DefaultChannelBuilder
+
+                self.port = DefaultChannelBuilder.default_port()
+            else:
+                self.port = ChannelBuilder.default_port()
+        elif len(netloc) == 2:
+            self.host = netloc[0]
+            self.port = int(netloc[1])
+        else:
+            exception = PySparkValueError(
+                error_class="INVALID_CONNECT_URL",
+                message_parameters={
+                    "detail": f"Target destination '{self.url.netloc}' should match the "
+                    f"'<host>:<port>' pattern. Please update the destination to follow "
+                    f"the correct format, e.g., 'hostname:port'.",
+                },
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_SPARK_CONNECT_URL)
+            raise exception
+
+    # We override this to enable compatibility with Spark 4.0
+    host = None
+
+    @property
+    def endpoint(self) -> str:
+        return f"{self.url.scheme}:{self.url.path}"
+
+    def toChannel(self) -> grpc.Channel:
+        return grpc.insecure_channel(self.endpoint, options=self._channel_options)
+
+
+def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession:
+    """
+    Returns spark connect session.
+
+    Parameters:
+        url (Optional[str]): Spark connect server URL. Uses default server URL if none is provided.
+        conf (SparkConf): Optional Spark configuration.
+
+    Returns:
+        A new spark connect session.
+
+    Raises:
+        RuntimeError: If Spark Connect server is not started.
+    """
+    try:
+        if not url:
+            url = get_client_url()
+
+        if url.startswith("unix:/"):
+            b = SparkSession.builder.channelBuilder(UnixDomainSocketChannelBuilder())
+        else:
+            b = SparkSession.builder.remote(url)
+
+        if conf is not None:
+            for k, v in conf.getAll():
+                b.config(k, v)
+
+        return b.getOrCreate()
+    except Exception as e:
+        _reset_server_run_state()
+        logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
+        raise e
+
+
+def _setup_spark_environment(setup_java_home: bool = True) -> None:
+    """
+    Set up environment variables required for Spark Connect.
+
+    Parameters:
+        setup_java_home: If True, configures JAVA_HOME. Set to False for
+            lightweight client servers that don't need JVM.
+    """
+    if setup_java_home:
+        if os.environ.get("JAVA_HOME") is None:
+            try:
+                # For Notebooks on SPCS
+                from jdk4py import JAVA_HOME
+
+                os.environ["JAVA_HOME"] = str(JAVA_HOME)
+            except ModuleNotFoundError:
+                # For notebooks on Warehouse
+                conda_prefix = os.environ.get("CONDA_PREFIX")
+                if conda_prefix is not None:
+                    os.environ["JAVA_HOME"] = conda_prefix
+                    os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
+                        conda_prefix, "lib", "server"
+                    )
+        logger.info("JAVA_HOME=%s", os.environ.get("JAVA_HOME", "Not defined"))
+
+    os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
+    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+
+
+def _disable_protobuf_recursion_limit() -> None:
+    """
+    Disable protobuf recursion limit.
+
+    https://github.com/protocolbuffers/protobuf/blob/960e79087b332583c80537c949621108a85aa442/src/google/protobuf/io/coded_stream.h#L616
+    Disable protobuf recursion limit (default 100) because Spark workloads often
+    produce deeply nested execution plans. For example:
+    - Queries with many unions
+    - Complex expressions with multiple levels of nesting
+    Without this, legitimate Spark queries would fail with
+    `(DecodeError) Error parsing message with type 'spark.connect.Relation'` error.
+    See test_sql_resulting_in_nested_protobuf
+    """
+    from google.protobuf.pyext import cpp_message
+
+    cpp_message._message.SetAllowOversizeProtos(True)
+
+
+def setup_signal_handlers(stop_event: threading.Event) -> None:
+    """Set up signal handlers for graceful shutdown."""
+
+    def make_signal_handler(stop_event):
+        def signal_handler(signum, frame):
+            logger.info(f"Received signal {signum}, stopping server gracefully...")
+            stop_event.set()
+
+        return signal_handler
+
+    try:
+        signal_handler = make_signal_handler(stop_event)
+        signal.signal(signal.SIGTERM, signal_handler)  # kill <pid>
+        signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+        if hasattr(signal, "SIGHUP"):
+            signal.signal(signal.SIGHUP, signal_handler)  # Terminal hangup
+        logger.info("Signal handlers registered for graceful shutdown")
+    except Exception as e:
+        logger.warning(f"Failed to register signal handlers: {e}")
+
+
+def configure_server_url(
+    remote_url: Optional[str] = None,
+    tcp_port: Optional[int] = None,
+    unix_domain_socket: Optional[str] = None,
+) -> Optional[str]:
+    """
+    Configure server URL based on provided parameters or environment.
+
+    Returns the unix_domain_socket path if one was created, None otherwise.
+    """
+    if len(list(filter(None, [remote_url, tcp_port, unix_domain_socket]))) > 1:
+        exception = RuntimeError(
+            "Can only set at most one of remote_url, tcp_port, and unix_domain_socket"
+        )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_INPUT)
+        raise exception
+
+    url_from_env = os.environ.get("SPARK_REMOTE", None)
+    created_socket = None
+
+    if remote_url:
+        _set_remote_url(remote_url)
+    elif tcp_port:
+        _set_server_tcp_port(tcp_port)
+    elif unix_domain_socket:
+        _set_server_unix_domain_socket(unix_domain_socket)
+    elif url_from_env:
+        # Spark clients use environment variable SPARK_REMOTE to figure out Spark Connect URL
+        _set_remote_url(url_from_env)
+    else:
+        # No connection properties can be found - use Unix Domain Socket as fallback
+        if os.name == "nt":
+            # Windows does not support unix domain sockets
+            _set_server_tcp_port(DEFAULT_PORT)
+        else:
+            # Generate unique, random UDS port
+            created_socket = _make_unix_domain_socket()
+            _set_server_unix_domain_socket(created_socket)
+
+    return created_socket
+
+
+def validate_startup_parameters(
+    snowpark_session: Optional[snowpark.Session],
+    connection_parameters: Optional[Dict[str, str]],
+) -> Optional[snowpark.Session]:
+    """
+    Validate startup parameters and create snowpark session if needed.
+
+    Returns the snowpark session to use.
+    """
+    if os.environ.get("SPARK_ENV_LOADED"):
+        exception = RuntimeError(
+            "Snowpark Connect cannot be run inside of a Spark environment"
+        )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_OPERATION)
+        raise exception
+
+    if connection_parameters is not None:
+        if snowpark_session is not None:
+            exception = ValueError(
+                "Only specify one of snowpark_session and connection_parameters"
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_STARTUP_INPUT)
+            raise exception
+        snowpark_session = snowpark.Session.builder.configs(
+            connection_parameters
+        ).create()
+
+    return snowpark_session
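
For orientation, a minimal usage sketch of the new `server_common` helpers (not taken from the package). It assumes a Snowpark Connect server has already been started elsewhere and has registered its client URL through `configure_server_url`:

```python
# Hypothetical usage sketch: connecting to an already-running Snowpark Connect server.
from pyspark.conf import SparkConf

from snowflake.snowpark_connect import server_common

conf = SparkConf()
conf.set("spark.sql.session.timeZone", "UTC")

# get_session() resolves the client URL recorded at server startup; "unix:" URLs
# are routed through UnixDomainSocketChannelBuilder, anything else through
# SparkSession.builder.remote().
spark = server_common.get_session(conf=conf)
print(spark.range(3).count())
```

`configure_server_url()` itself is the pre-startup half of this flow: it validates that at most one of `remote_url`, `tcp_port`, and `unix_domain_socket` is given, falls back to `SPARK_REMOTE` or a freshly created Unix domain socket, and checks that a TCP port is still free before the server binds it.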

snowflake/snowpark_connect/snowflake_session.py
@@ -0,0 +1,65 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import zlib
+
+from pyspark.sql import DataFrame, SparkSession
+
+SQL_PASS_THROUGH_MARKER = "PRIVATE-SNOWFLAKE-SQL"
+
+
+def calculate_checksum(data: str) -> str:
+    checksum = zlib.crc32(data.encode("utf-8"))
+    return format(checksum, "08X")
+
+
+class SnowflakeSession:
+    """
+    Provides a wrapper around SparkSession to enable Snowflake SQL pass-through functionality.
+    Also provides helper methods to switch to different database, schema, role, warehouse, etc.
+    """
+
+    def __init__(self, spark_session: SparkSession) -> None:
+        self.spark_session = spark_session
+
+    def sql(self, sql_stmt: str) -> DataFrame:
+        """
+        Execute Snowflake specific SQL directly against Snowflake.
+        """
+        checksum = calculate_checksum(sql_stmt)
+        return self.spark_session.sql(
+            f"{SQL_PASS_THROUGH_MARKER} {checksum} {sql_stmt}"
+        )
+
+    def use_database(self, database: str, preserve_case: bool = False) -> DataFrame:
+        """
+        Switch to the database specified by `database`.
+        """
+        if preserve_case:
+            database = f'"{database}"'
+        return self.sql(f"USE DATABASE {database}")
+
+    def use_schema(self, schema: str, preserve_case: bool = False) -> DataFrame:
+        """
+        Switch to the schema specified by `schema`.
+        """
+        if preserve_case:
+            schema = f'"{schema}"'
+        return self.sql(f"USE SCHEMA {schema}")
+
+    def use_role(self, role: str, preserve_case: bool = False) -> DataFrame:
+        """
+        Switch to the role specified by `role`.
+        """
+        if preserve_case:
+            role = f'"{role}"'
+        return self.sql(f"USE ROLE {role}")
+
+    def use_warehouse(self, warehouse: str, preserve_case: bool = False) -> DataFrame:
+        """
+        Switch to the warehouse specified by `warehouse`.
+        """
+        if preserve_case:
+            warehouse = f'"{warehouse}"'
+        return self.sql(f"USE WAREHOUSE {warehouse}")
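
The wrapper prefixes each statement with the pass-through marker and a CRC32 checksum before handing it to the underlying Spark Connect session. A hypothetical usage sketch, assuming `spark` is an existing PySpark `SparkSession`:

```python
# Hypothetical usage sketch (not part of the package).
from snowflake.snowpark_connect.snowflake_session import SnowflakeSession

sf = SnowflakeSession(spark)  # `spark`: an existing Spark Connect SparkSession

# Each statement becomes "PRIVATE-SNOWFLAKE-SQL <CRC32-hex> <statement>",
# e.g. "PRIVATE-SNOWFLAKE-SQL 1A2B3C4D SHOW WAREHOUSES" (checksum illustrative).
sf.sql("SHOW WAREHOUSES").show()

sf.use_database("analytics")                # USE DATABASE analytics
sf.use_schema("Sales", preserve_case=True)  # USE SCHEMA "Sales"
```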

snowflake/snowpark_connect/start_server.py
@@ -5,15 +5,34 @@

 import argparse
 import logging
+import threading
+
+from snowflake.snowpark_connect.utils.spcs_logger import setup_spcs_logger

 if __name__ == "__main__":
     from snowflake.snowpark_connect.server import start_session
+    from snowflake.snowpark_connect.utils.snowpark_connect_logging import (
+        ensure_logger_has_handler,
+    )

     parser = argparse.ArgumentParser()
-
-
-
-
+    # Connection options are mutually exclusive
+    connection_group = parser.add_mutually_exclusive_group()
+    connection_group.add_argument("--tcp-port", type=int)
+    connection_group.add_argument("--unix-domain-socket", type=str)
+
+    # Logging options are independent
+    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument(
+        "--disable-spcs-log-format",
+        action="store_true",
+        help="Disable SPCS (Snowpark Container Services) log format",
+    )
+    parser.add_argument(
+        "--disable-signal-handlers",
+        action="store_true",
+        help="Disable signal handlers (SIGTERM, SIGINT, SIGHUP) for graceful shutdown",
+    )

     args = parser.parse_args()
     unix_domain_socket = args.unix_domain_socket
@@ -21,12 +40,41 @@ if __name__ == "__main__":
     if not unix_domain_socket and not tcp_port:
         tcp_port = 15002  # default spark connect server port

+    log_level = logging.INFO
     if args.verbose:
+        log_level = logging.DEBUG
+
+    # Configure other loggers - clear handlers first for clean setup
+    loggers_to_configure = [
+        "snowflake.snowpark",
+        "snowflake.connector",
+        "snowflake.connector.connection",
+        "snowflake_connect_server",
+    ]
+    # Set up the logger based on environment
+    if not args.disable_spcs_log_format:
+        # Initialize SPCS log format when running in Snowpark Container Services (default)
+        logger = setup_spcs_logger(
+            log_level=log_level,
+            enable_console_output=False,  # Shows human-readable logs to stderr
+        )
+    else:
+        for logger_name in loggers_to_configure:
+            target_logger = logging.getLogger(logger_name)
+            target_logger.handlers.clear()
+            configured_logger = ensure_logger_has_handler(
+                logger_name, log_level, force_level=True
+            )
+        # Get the logger for use in signal handlers
         logger = logging.getLogger("snowflake_connect_server")
-
+
+    # Create stop_event and optionally set up signal handlers in start_server
+    stop_event = threading.Event()

     start_session(
         is_daemon=False,
         tcp_port=tcp_port,
         unix_domain_socket=unix_domain_socket,
+        stop_event=stop_event,
+        _add_signal_handler=(not args.disable_signal_handlers),
     )
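
The updated `__main__` block now builds a `stop_event` and forwards the new logging and signal-handling options to `start_session`. A rough programmatic equivalent is sketched below; `start_session`'s full signature is not visible in this diff, so treat the call as illustrative only:

```python
# Minimal sketch, assumptions noted in comments: starting the server
# programmatically with the same arguments the updated entry point passes.
import threading

from snowflake.snowpark_connect.server import start_session

stop_event = threading.Event()
start_session(
    is_daemon=False,
    tcp_port=15002,            # or pass unix_domain_socket=<socket path> instead
    unix_domain_socket=None,
    stop_event=stop_event,     # set this event to request a graceful shutdown
    _add_signal_handler=True,  # register SIGTERM/SIGINT/SIGHUP handlers
)
```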