snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
@@ -21,17 +21,13 @@
 # limitations under the License.
 #

-
-import logging
+
 import os
-import
-import socket
+import sys
 import tempfile
 import threading
-import urllib.parse
-import zipfile
 from concurrent import futures
-from typing import
+from typing import Callable, Dict, List, Optional

 import grpc
 import jpype
@@ -41,14 +37,10 @@ import pyspark.sql.connect.proto.base_pb2_grpc as proto_base_grpc
 import pyspark.sql.connect.proto.common_pb2 as common_proto
 import pyspark.sql.connect.proto.relations_pb2 as relations_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
-from packaging import version
 from pyspark import StorageLevel
 from pyspark.conf import SparkConf
-from pyspark.errors import PySparkValueError
-from pyspark.sql.connect.client.core import ChannelBuilder
 from pyspark.sql.connect.session import SparkSession

-import snowflake.snowpark_connect
 import snowflake.snowpark_connect.proto.control_pb2_grpc as control_grpc
 import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
@@ -56,7 +48,11 @@ from snowflake.snowpark_connect.analyze_plan.map_tree_string import map_tree_str
 from snowflake.snowpark_connect.config import route_config_proto
 from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
 from snowflake.snowpark_connect.control_server import ControlServicer
-from snowflake.snowpark_connect.error.
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import (
+    attach_custom_error_code,
+    build_grpc_error_response,
+)
 from snowflake.snowpark_connect.execute_plan.map_execution_command import (
     map_execution_command,
 )
@@ -66,7 +62,26 @@ from snowflake.snowpark_connect.execute_plan.map_execution_root import (
 from snowflake.snowpark_connect.relation.map_local_relation import map_local_relation
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.utils import get_semantic_string
-from snowflake.snowpark_connect.resources_initializer import
+from snowflake.snowpark_connect.resources_initializer import initialize_resources
+from snowflake.snowpark_connect.server_common import ( # noqa: F401 - re-exported for public API
+    _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    _client_telemetry_context,
+    _disable_protobuf_recursion_limit,
+    _get_default_grpc_options,
+    _reset_server_run_state,
+    _setup_spark_environment,
+    _stop_server,
+    configure_server_url,
+    get_client_url,
+    get_server_error,
+    get_server_running,
+    get_server_url,
+    get_session,
+    set_grpc_max_message_size,
+    set_server_error,
+    setup_signal_handlers,
+    validate_startup_parameters,
+)
 from snowflake.snowpark_connect.type_mapping import (
     map_type_string_to_proto,
     snowpark_to_proto_type,
@@ -82,12 +97,13 @@ from snowflake.snowpark_connect.utils.cache import (
     df_cache_map_put_if_absent,
 )
 from snowflake.snowpark_connect.utils.context import (
+    clean_request_external_tables,
     clear_context_data,
-
-
+    get_request_external_tables,
+    get_spark_session_id,
+    set_spark_session_id,
     set_spark_version,
 )
-from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
 from snowflake.snowpark_connect.utils.external_udxf_cache import (
     clear_external_udxf_cache,
 )
@@ -96,7 +112,25 @@ from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_queries_with_tag,
     interrupt_query,
 )
-from snowflake.snowpark_connect.utils.
+from snowflake.snowpark_connect.utils.java_stored_procedure import (
+    set_java_udf_creator_initialized_state,
+)
+from snowflake.snowpark_connect.utils.open_telemetry import (
+    is_telemetry_enabled,
+    otel_attach_context,
+    otel_create_context_wrapper,
+    otel_create_status,
+    otel_detach_context,
+    otel_end_root_span,
+    otel_flush_telemetry,
+    otel_get_current_span,
+    otel_get_root_span_context,
+    otel_get_status_code,
+    otel_get_tracer,
+    otel_initialize,
+    otel_start_span_as_current,
+)
+from snowflake.snowpark_connect.utils.profiling import PROFILING_ENABLED, profile_method
 from snowflake.snowpark_connect.utils.session import (
     configure_snowpark_session,
     get_or_create_snowpark_session,
@@ -112,29 +146,111 @@ from snowflake.snowpark_connect.utils.telemetry import (
 )
 from snowflake.snowpark_connect.utils.xxhash64 import xxhash64_string

-DEFAULT_PORT = 15002

-
-
-
-
+def _store_client_stack_trace(client_stack_info):
+    """Store client stack trace in thread-local storage"""
+
+    _client_telemetry_context.stack_trace = client_stack_info
+
+
+def _clear_client_stack_trace():
+    """Clear client stack trace"""
+
+    _client_telemetry_context.stack_trace = None
+
+
+def _get_client_stack_trace():
+    """Get current client stack trace"""
+
+    return getattr(_client_telemetry_context, "stack_trace", None)
+

+def _add_client_stack_trace_to_span(span, client_stack):
+    """
+    Add formatted client stack trace to a specific span.

-
+    Args:
+        span: The OpenTelemetry span to add the stack trace attribute to
+        client_stack: The client stack trace data (list of frame dicts)
     """
-
-
+    if not client_stack or not span or not span.is_recording():
+        return
+
+    stack_frames = []
+    for frame in client_stack:
+        if frame.get("file_name") and frame.get("line_number"):
+            method = frame.get("method_name", "unknown")
+            location = f"{frame.get('file_name')}:{frame.get('line_number')}"
+            stack_frames.append(f"{method} at {location}")
+
+    if stack_frames:
+        span.set_attribute("client.stack_trace", " <- ".join(stack_frames))
+
+
+def _process_and_store_client_stack_trace(request, add_to_span: bool = False):
     """
-
+    Extract, store, and optionally add client stack trace to the current span.

-
-
-
+    Args:
+        request: The gRPC request containing user context with stack trace
+        add_to_span: If True, format and add stack trace as span attribute to current span

-
-
+    Returns:
+        The extracted client_stack (or None) for use in ExecutePlan
+    """
+    # Extract and store client stack trace information for telemetry
+    client_stack = _extract_and_log_user_stack_trace(request)
+    if client_stack:
+        _store_client_stack_trace(client_stack)

-
+    # Set span attribute with formatted stack trace (if requested and available)
+    if add_to_span and client_stack:
+        root_span_otel_context = otel_get_root_span_context()
+        if root_span_otel_context is not None and is_telemetry_enabled():
+            current_span = otel_get_current_span()
+            if current_span and current_span.is_recording():
+                _add_client_stack_trace_to_span(current_span, client_stack)
+
+    return client_stack
+
+
+def _extract_and_log_user_stack_trace(request):
+    """
+    Extract and log user stack trace information from request extensions.
+
+    Args:
+        request: The gRPC request containing user_context.extensions
+
+    Returns:
+        List of stack trace frames or None if no traces found
+    """
+    try:
+        from snowflake.snowpark_connect.utils.patch_spark_line_number import (
+            extract_stack_trace_from_extensions,
+        )
+
+        if hasattr(request, "user_context") and hasattr(
+            request.user_context, "extensions"
+        ):
+            stack_traces = extract_stack_trace_from_extensions(
+                request.user_context.extensions
+            )
+
+            if stack_traces:
+                logger.debug("User code stack trace:")
+                for i, frame in enumerate(stack_traces):
+                    logger.debug(
+                        f" Frame {i}: {frame.get('method_name', 'unknown')} "
+                        f"at {frame.get('file_name', 'unknown')}:{frame.get('line_number', 'unknown')}"
+                    )
+                return stack_traces # Return the stack traces for telemetry use
+            else:
+                logger.debug("No user stack trace information found in request")
+        return None
+    except Exception as e:
+        # Don't let stack trace extraction errors affect the main request
+        logger.debug(f"Failed to extract user stack trace: {e}")
+        return None


 def _handle_exception(context, e: Exception):
@@ -147,16 +263,15 @@ def _handle_exception(context, e: Exception):
     if show_traceback:
         # Show detailed traceback (includes error info naturally)
         error_traceback = traceback.format_exc()
-
-        logger.error(sanitized_traceback)
+        logger.error(error_traceback)
     else:
         # Show only basic error information, no traceback
         logger.error("Error: %s - %s", type(e).__name__, str(e))

     telemetry.report_request_failure(e)
-
     if tcm.TCM_MODE:
-        #
+        # spark decoder will catch the error and return it to GS gracefully
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e

     from grpc_status import rpc_status
@@ -165,14 +280,71 @@ def _handle_exception(context, e: Exception):
     context.abort_with_status(rpc_status.to_status(rich_status))


+# Decorator for creating method spans as children of root span
+def _with_method_span(method_name):
+    """
+    Decorator to create a new span as child of root span for gRPC methods and provide it as parent to Snowpark operations.
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            # Get the root span context first
+            root_span_otel_context = otel_get_root_span_context()
+
+            # Only proceed if BOTH conditions are true
+            if root_span_otel_context is not None and is_telemetry_enabled():
+                # Attach the root context first, then create child span
+                context_token = otel_attach_context(root_span_otel_context)
+
+                try:
+                    tracer = otel_get_tracer(__name__)
+                    span_name = f"snowpark_connect.{method_name}"
+
+                    # Create span as child of the root span context
+                    span_context_mgr = otel_start_span_as_current(tracer, span_name)
+                    if span_context_mgr:
+                        with span_context_mgr as span:
+                            try:
+                                # Execute the method with the new span as current context
+                                return func(*args, **kwargs)
+
+                            except Exception as e:
+                                # Record the exception in the span
+                                span.record_exception(e)
+                                StatusCode = otel_get_status_code()
+                                if StatusCode:
+                                    status = otel_create_status(
+                                        StatusCode.ERROR, str(e)
+                                    )
+                                    if status:
+                                        span.set_status(status)
+                                raise
+                    else:
+                        # No span created, just execute the function
+                        return func(*args, **kwargs)
+
+                finally:
+                    # Always detach the root context
+                    if context_token is not None:
+                        otel_detach_context(context_token)
+            else:
+                # No root context available or OTel not available, execute without span
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+# Snowflake Connect gRPC Service Implementation
 class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
     def __init__(
         self,
         log_request_fn: Optional[Callable[[bytearray], None]] = None,
     ) -> None:
         self.log_request_fn = log_request_fn
-        # Trigger
-
+        # Trigger synchronous initialization here, so that we reduce overhead for rpc calls.
+        initialize_resources()

     @profile_method
     def ExecutePlan(self, request: proto_base.ExecutePlanRequest, context):
@@ -181,20 +353,45 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
         """
         logger.info("ExecutePlan")
+
+        client_stack = _process_and_store_client_stack_trace(request, add_to_span=False)
+
         if self.log_request_fn is not None:
             self.log_request_fn(request.SerializeToString())

         # TODO: remove session id context when we host this in Snowflake server
         # set the thread-local context of session id
         clear_context_data()
-
+        set_spark_session_id(request.session_id)
         set_spark_version(request.client_type)
         telemetry.initialize_request_summary(request)

         set_query_tags(request.tags)

-
+        # Additional context attachment for Snowpark DataFrame operations
+        snowpark_context_token = None
+        span = None
+        span_context_manager = None
         try:
+            root_span_otel_context = otel_get_root_span_context()
+
+            if root_span_otel_context is not None and is_telemetry_enabled():
+                snowpark_context_token = otel_attach_context(root_span_otel_context)
+
+                # Create span manually for generator function and make it current
+                tracer = otel_get_tracer(__name__)
+                span_context_manager = otel_start_span_as_current(
+                    tracer, "snowpark_connect.ExecutePlan"
+                )
+                span = None
+                if span_context_manager:
+                    span = (
+                        span_context_manager.__enter__()
+                    ) # Start the span context AND make it current
+                    # Add stack trace to this manually created span
+                    _add_client_stack_trace_to_span(span, client_stack)
+
+            result_iter = iter(())
             match request.plan.WhichOneof("op_type"):
                 case "root":
                     logger.info("ROOT")
@@ -212,32 +409,60 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     result_complete=proto_base.ExecutePlanResponse.ResultComplete(),
                 )
         except Exception as e:
+            if span:
+                span.record_exception(e)
+                StatusCode = otel_get_status_code()
+                if StatusCode:
+                    status = otel_create_status(StatusCode.ERROR, str(e))
+                    if status:
+                        span.set_status(status)
             _handle_exception(context, e)
         finally:
+            if span_context_manager:
+                span_context_manager.__exit__(None, None, None) # End the span
+            if snowpark_context_token is not None:
+                otel_detach_context(snowpark_context_token)
+            # Clear client stack trace when request is done
+            _clear_client_stack_trace()
+            otel_flush_telemetry()
+            self._cleanup_external_tables()
             telemetry.send_request_summary_telemetry()

     @profile_method
+    @_with_method_span("AnalyzePlan")
     def AnalyzePlan(self, request: proto_base.AnalyzePlanRequest, context):
         """Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query."""
         logger.info(f"AnalyzePlan: {request.WhichOneof('analyze')}")
+
+        _process_and_store_client_stack_trace(request, add_to_span=True)
+
         if self.log_request_fn is not None:
             self.log_request_fn(request.SerializeToString())
+
         try:
             # TODO: remove session id context when we host this in Snowflake server
             # set the thread-local context of session id
             clear_context_data()
-
+            set_spark_session_id(request.session_id)
             set_spark_version(request.client_type)
             telemetry.initialize_request_summary(request)
             match request.WhichOneof("analyze"):
                 case "schema":
                     result = map_relation(request.schema.plan.root)
-
-
+
+                    from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                        without_internal_columns,
+                    )
+
+                    filtered_result = without_internal_columns(result)
+                    filtered_df = filtered_result.dataframe
+
                     schema = proto_base.AnalyzePlanResponse.Schema(
                         schema=types_proto.DataType(
                             **snowpark_to_proto_type(
-
+                                filtered_df.schema,
+                                filtered_result.column_map,
+                                filtered_df,
                             )
                         )
                     )
@@ -274,10 +499,15 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     plan_id = request.persist.relation.common.plan_id
                     # cache the plan if it is not already in the map

+                    from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                        without_internal_columns,
+                    )
+
                     df_cache_map_put_if_absent(
                         (request.session_id, plan_id),
-                        lambda:
-
+                        lambda: without_internal_columns(
+                            map_relation(request.persist.relation)
+                        ),
                     )

                     storage_level = request.persist.storage_level
@@ -366,15 +596,24 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                         ),
                     )
                 case _:
-
+                    exception = SnowparkConnectNotImplementedError(
                         f"ANALYZE PLAN NOT IMPLEMENTED:\n{request}"
                     )
+                    attach_custom_error_code(
+                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    )
+                    raise exception
         except Exception as e:
             _handle_exception(context, e)
         finally:
+            # Clear client stack trace when request is done
+            _clear_client_stack_trace()
+            otel_flush_telemetry()
+            self._cleanup_external_tables()
             telemetry.send_request_summary_telemetry()

     @staticmethod
+    @_with_method_span("Config")
     def Config(
         request: proto_base.ConfigRequest,
         context,
@@ -389,12 +628,18 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
     ):
         """Update or fetch the configurations and returns a [[ConfigResponse]] containing the result."""
         logger.info("Config")
+
+        _process_and_store_client_stack_trace(request, add_to_span=True)
+
         try:
             telemetry.initialize_request_summary(request)
             return route_config_proto(request, get_or_create_snowpark_session())
         except Exception as e:
             _handle_exception(context, e)
         finally:
+            # Clear client stack trace when request is done
+            _clear_client_stack_trace()
+            otel_flush_telemetry()
             telemetry.send_request_summary_telemetry()

     def AddArtifacts(self, request_iterator, context):
@@ -402,11 +647,9 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         the added artifacts.
         """
         logger.info("AddArtifacts")
+
         session: snowpark.Session = get_or_create_snowpark_session()
-        filenames: dict[str, str] = {}
         response: dict[str, proto_base.AddArtifactsResponse.ArtifactSummary] = {}
-        # Store accumulated data for local relation cache
-        cache_data: dict[str, bytearray] = {}

         def _try_handle_local_relation(artifact_name: str, data: bytes):
             """
@@ -422,12 +665,14 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
             ) # heuristic to identify local relations

            def _handle_regular_artifact():
-
+                artifact = write_artifact(
                     session,
                     artifact_name,
                     data,
                     overwrite=True,
                 )
+                with session._filenames_lock:
+                    session._filenames[get_spark_session_id()][artifact_name] = artifact

             if is_likely_local_relation:
                 try:
@@ -435,9 +680,8 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                     l_relation.ParseFromString(data)
                     relation = relations_proto.Relation(local_relation=l_relation)
                     df_cache_map_put_if_absent(
-                        (
+                        (get_spark_session_id(), artifact_name.replace("cache/", "")),
                         lambda: map_local_relation(relation), # noqa: B023
-                        materialize=True,
                     )
                 except Exception as e:
                     logger.warning("Failed to put df into cache: %s", str(e))
@@ -458,29 +702,46 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         # Batch artifacts are sent as a single "batch" message containing a list of
         # artifacts. We do not need to keep track of the name since it is included in
         # each artifact.
-
+
         for request in request_iterator:
             clear_context_data()
-
+            set_spark_session_id(request.session_id)
             set_spark_version(request.client_type)
+            with session._filenames_lock:
+                if request.session_id not in session._filenames:
+                    session._filenames[request.session_id] = {}
+
             match request.WhichOneof("payload"):
                 case "begin_chunk":
                     current_name = request.begin_chunk.name
-
-                    current_name
-
+                    current_chunk = {
+                        "name": current_name,
+                        "num_chunks": request.begin_chunk.num_chunks,
+                        "current_chunk_index": 1,
+                    }
+                    with session._filenames_lock:
+                        assert (
+                            current_name not in session._filenames[request.session_id]
+                        ), "Duplicate artifact name found."

                     if current_name.startswith("cache/"):
-
+                        current_chunk["cache"] = bytearray(
                             request.begin_chunk.initial_chunk.data
                         )
                     else:
-
+                        artifact = write_artifact(
                             session,
                             current_name,
                             request.begin_chunk.initial_chunk.data,
                             overwrite=True,
                         )
+                        with session._filenames_lock:
+                            session._filenames[request.session_id][
+                                current_name
+                            ] = artifact
+                    # cache current chunk
+                    with session._current_chunk_lock:
+                        session._current_chunk[request.session_id] = current_chunk
                     response[
                         current_name
                     ] = proto_base.AddArtifactsResponse.ArtifactSummary(
@@ -491,18 +752,53 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                         ),
                     )
                 case "chunk":
+                    # retrieve current chunk
+                    with session._current_chunk_lock:
+                        if request.session_id not in session._current_chunk:
+                            exception = ValueError(
+                                f"Received 'chunk' for session_id '{request.session_id}' without a prior 'begin_chunk'."
+                            )
+                            attach_custom_error_code(
+                                exception, ErrorCodes.INTERNAL_ERROR
+                            )
+                            raise exception
+                        current_chunk = session._current_chunk[request.session_id]
+
+                    current_name = current_chunk["name"]
+                    current_chunk["current_chunk_index"] += 1
                     if current_name.startswith("cache/"):
-
+                        current_chunk["cache"].extend(request.chunk.data)
                     else:
-
+                        artifact = write_artifact(
                             session, current_name, request.chunk.data
-                        )
+                        )
+                        with session._filenames_lock:
+                            assert (
+                                session._filenames[request.session_id][current_name]
+                                == artifact
+                            ), "Artifact staging error."
+
+                    if (
+                        current_chunk["current_chunk_index"]
+                        == current_chunk["num_chunks"]
+                    ):
+                        # all chunks are ready
+                        if current_name.startswith("cache/"):
+                            _try_handle_local_relation(
+                                current_name, bytes(current_chunk["cache"])
+                            )
+                        with session._current_chunk_lock:
+                            # remove current chunk from session
+                            del session._current_chunk[request.session_id]

                     response[
                         current_name
                     ] = proto_base.AddArtifactsResponse.ArtifactSummary(
                         name=current_name,
-                        is_crc_successful=
+                        is_crc_successful=(
+                            current_name not in response
+                            or response[current_name].is_crc_successful
+                        )
                         and check_checksum(request.chunk.data, request.chunk.crc),
                     )
                 case "batch":
@@ -519,62 +815,89 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                         ),
                     )
                 case _:
-
+                    exception = ValueError(
                         f"Unexpected payload type in AddArtifacts: {request.WhichOneof('payload')}"
                     )
-
-
-
+                    attach_custom_error_code(
+                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    )
+                    raise exception
+
+        # if current chunk is still not finished, just return here
+        # This should only happen in TCM since we have to send request via rest one by one so current chunk cannot be
+        # finished in one iteration
+        with session._current_chunk_lock:
+            if request.session_id in session._current_chunk:
+                return proto_base.AddArtifactsResponse(
+                    artifacts=list(response.values())
+                )

         class_files: dict[str, str] = {}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if name.startswith("cache"):
-                continue
-
-            # Remove temporary stored files which are put on the stage
-            os.remove(filepath)
-
-            # Add only files marked to be used in user generated Python UDFs.
-            cached_name = f"{session.get_session_stage()}/{filepath.split('/')[-1]}"
-            if not name.startswith("pyfiles") and cached_name in session._python_files:
-                session._python_files.remove(cached_name)
-            elif name.startswith("pyfiles"):
-                session._python_files.add(cached_name)
-
-            if not name.startswith("pyfiles"):
-                session._import_files.add(cached_name)
-
-        if class_files:
-            write_class_files_to_stage(session, class_files)
+        with session._filenames_lock:
+            for (name, filepath) in session._filenames[get_spark_session_id()].items():
+                if name.endswith(".class"):
+                    # name is <dir>/<package>/<class_name>
+                    # we don't need the dir name, but require the package, so only remove dir
+                    if os.name != "nt":
+                        class_files[name.split("/", 1)[-1]] = filepath
+                    else:
+                        class_files[name.split("\\", 1)[-1]] = filepath
+                    continue
+                session.file.put(
+                    filepath,
+                    session.get_session_stage(),
+                    auto_compress=False,
+                    overwrite=True,
+                    source_compression="GZIP" if name.endswith(".gz") else "NONE",
+                )

-
-
+                if name.startswith("cache"):
+                    continue
+
+                # Add only files marked to be used in user generated Python UDFs.
+                cached_name = f"{session.get_session_stage()}/{filepath.split('/')[-1]}"
+                if (
+                    not name.startswith("pyfiles")
+                    and cached_name in session._python_files
+                ):
+                    session._python_files.remove(cached_name)
+                elif name.startswith("pyfiles"):
+                    session._python_files.add(cached_name)
+
+                if name.startswith("jars/"):
+                    session._artifact_jars.add(cached_name)
+                    # Recreate the Java procedure to reload jars
+                    set_java_udf_creator_initialized_state(False)
+                elif not name.startswith("pyfiles"):
+                    session._import_files.add(cached_name)
+
+                # Remove temporary stored files which are put on the stage
+                os.remove(filepath)
+
+        if class_files:
+            jar_name = write_class_files_to_stage(session, class_files)
+            session._artifact_jars.add(jar_name)
+
+        if any(
+            not name.startswith("cache")
+            for name in session._filenames[get_spark_session_id()].keys()
+        ):
+            clear_external_udxf_cache(session)
+
+        # clear filenames for this session
+        session._filenames[get_spark_session_id()] = {}

         return proto_base.AddArtifactsResponse(artifacts=list(response.values()))

     def ArtifactStatus(self, request, context):
         """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
         logger.info("ArtifactStatus")
+
         clear_context_data()
-
+        set_spark_session_id(request.session_id)
         set_spark_version(request.client_type)
         session: snowpark.Session = get_or_create_snowpark_session()
+
         if os.name != "nt":
             tmp_path = f"/tmp/sas-{session.session_id}/"
         else:
@@ -583,7 +906,7 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         def _is_local_relation_cached(name: str) -> bool:
             if name.startswith("cache/"):
                 hash = name.replace("cache/", "")
-                cached_df = df_cache_map_get((
+                cached_df = df_cache_map_get((get_spark_session_id(), hash))
                 return cached_df is not None
             return False

@@ -618,6 +941,7 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         # instead of using operation ids, we're relying on Snowflake query ids here, meaning that:
         # - The list of returned interrupted_ids contains query ids of interrupted jobs, instead of their operation ids
         # - INTERRUPT_TYPE_OPERATION_ID interrupt type expects a Snowflake query id instead of an operation id
+
         try:
             match request.interrupt_type:
                 case proto_base.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL:
@@ -627,9 +951,13 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
                 case proto_base.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID:
                     interrupted_ids = interrupt_query(request.operation_id)
                 case _:
-
+                    exception = SnowparkConnectNotImplementedError(
                         f"INTERRUPT NOT IMPLEMENTED:\n{request}"
                     )
+                    attach_custom_error_code(
+                        exception, ErrorCodes.UNSUPPORTED_OPERATION
+                    )
+                    raise exception

         return proto_base.InterruptResponse(
             session_id=request.session_id,
@@ -647,9 +975,12 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         continue. If there is a ResultComplete, the client should use ReleaseExecute with
         """
         logger.info("ReattachExecute")
-
+
+        exception = SnowparkConnectNotImplementedError(
             "Spark client has detached, please resubmit request. In a future version, the server will be support the reattach."
         )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception

     def ReleaseExecute(self, request: proto_base.ReleaseExecuteRequest, context):
         """Release an reattachable execution, or parts thereof.
@@ -666,6 +997,18 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         except Exception as e:
             _handle_exception(context, e)

+    def _cleanup_external_tables(self):
+        external_tables = get_request_external_tables()
+        if not external_tables:
+            return
+        session: snowpark.Session = get_or_create_snowpark_session()
+        for table in external_tables:
+            try:
+                session.sql(f"DROP EXTERNAL TABLE IF EXISTS {table}").collect()
+            except Exception as e:
+                logger.warning(f"Failed to drop external table {table}: {e}")
+        clean_request_external_tables()
+
     # TODO: These are required in Spark 4.x.
     # def ReleaseSession(self, request, context):
     #     """Release a session.
@@ -682,39 +1025,16 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
     #     return super().FetchErrorDetails(request, context)


-# Global state related to server connection
-_server_running: threading.Event = threading.Event()
-_server_error: bool = False
-_server_url: Optional[str] = None
-_client_url: Optional[str] = None
-
-
-# Used to reset server global state to the initial blank slate state if error happens during server startup.
-# Called after the startup error is caught and handled / logged etc.
-def _reset_server_run_state():
-    global _server_running, _server_error, _server_url, _client_url
-    _server_running.clear()
-    _server_error = False
-    _server_url = None
-    _client_url = None
-
-
-def _stop_server(stop_event: threading.Event, server: grpc.Server):
-    stop_event.wait()
-    server.stop(0)
-    _reset_server_run_state()
-    logger.info("server stop sent")
-
-
 def _serve(
     stop_event: Optional[threading.Event] = None,
     session: Optional[snowpark.Session] = None,
 ):
-
+    server_running = get_server_running()
     # TODO: factor out the Snowflake connection code.
     server = None
     try:
         config_snowpark()
+
         if session is None:
             session = get_or_create_snowpark_session()
         else:
@@ -725,33 +1045,16 @@ def _serve(
             # No need to start grpc server in TCM
             return

-        server_options =
-
-
-
-
-                    _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
-                ),
-            ),
-            (
-                "grpc.max_metadata_size",
-                get_int_from_env(
-                    "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                    _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-                ),
-            ),
-            (
-                "grpc.absolute_max_metadata_size",
-                get_int_from_env(
-                    "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
-                    _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-                )
-                * 2,
-            ),
-        ]
+        server_options = _get_default_grpc_options()
+
+        # cProfile doesn't work correctly with multiple threads
+        max_workers = 1 if PROFILING_ENABLED else 10
+
         server = grpc.server(
-            futures.ThreadPoolExecutor(max_workers=
+            futures.ThreadPoolExecutor(max_workers=max_workers),
+            options=server_options,
         )
+
         control_servicer = ControlServicer(session)
         proto_base_grpc.add_SparkConnectServiceServicer_to_server(
             SnowflakeConnectServicer(control_servicer.log_spark_connect_batch),
@@ -762,193 +1065,33 @@ def _serve(
         server.add_insecure_port(server_url)
         logger.info(f"Starting Snowpark Connect server on {server_url}...")
         server.start()
-
+        server_running.set()
         logger.info("Snowpark Connect server started!")
         telemetry.send_server_started_telemetry()
+
         if stop_event is not None:
             # start a background thread to listen for stop event and terminate the server
             threading.Thread(
                 target=_stop_server, args=(stop_event, server), daemon=True
             ).start()
+
         server.wait_for_termination()
     except Exception as e:
-
-
+        set_server_error(True)
+        server_running.set() # unblock any client sessions
         if "Invalid connection_name 'spark-connect', known ones are " in str(e):
             logger.error(
                 "Ensure 'spark-connect' connection config has been set correctly in connections.toml."
             )
         else:
             logger.error("Error starting up Snowpark Connect server", exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e
     finally:
         # flush the telemetry queue if possible
         telemetry.shutdown()
-
-
-def _set_remote_url(remote_url: str):
-    global _server_url, _client_url
-    _client_url = remote_url
-    parsed_url = urllib.parse.urlparse(remote_url)
-    if parsed_url.scheme == "sc":
-        _server_url = parsed_url.netloc
-        server_port = parsed_url.port or DEFAULT_PORT
-        _check_port_is_free(server_port)
-    elif parsed_url.scheme == "unix":
-        _server_url = remote_url.split("/;")[0]
-    else:
-        raise RuntimeError(f"Invalid Snowpark Connect URL: {remote_url}")
-
-
-def _set_server_tcp_port(server_port: int):
-    global _server_url, _client_url
-    _check_port_is_free(server_port)
-    _server_url = f"[::]:{server_port}"
-    _client_url = f"sc://127.0.0.1:{server_port}"
-
-
-def _check_port_is_free(port: int) -> None:
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.settimeout(1)
-        if s.connect_ex(("127.0.0.1", port)) == 0:
-            raise RuntimeError(f"TCP port {port} is already in use")
-
-
-def _set_server_unix_domain_socket(path: str):
-    global _server_url, _client_url
-    _server_url = f"unix:{path}"
-    _client_url = f"unix:{path}"
-
-
-def get_server_url() -> str:
-    global _server_url
-    if not _server_url:
-        raise RuntimeError("Server URL not set")
-    return _server_url
-
-
-def get_client_url() -> str:
-    global _client_url
-    if not _client_url:
-        raise RuntimeError("Client URL not set")
-    return _client_url
-
-
-def _make_unix_domain_socket() -> str:
-    parent_dir = tempfile.mkdtemp()
-    server_path = os.path.join(parent_dir, "snowflake_sas_grpc.sock")
-    atexit.register(_cleanup_unix_domain_socket, server_path)
-    return server_path
-
-
-def _cleanup_unix_domain_socket(server_path: str) -> None:
-    parent_dir = os.path.dirname(server_path)
-    if os.path.exists(server_path):
-        os.remove(server_path)
-    if os.path.exists(parent_dir):
-        os.rmdir(parent_dir)
-
-
-class UnixDomainSocketChannelBuilder(ChannelBuilder):
-    """
-    Spark Connect gRPC channel builder for Unix domain sockets
-    """
-
-    def __init__(
-        self, url: str = None, channelOptions: Optional[List[Tuple[str, Any]]] = None
-    ) -> None:
-        if url is None:
-            url = get_client_url()
-        if url[:6] != "unix:/" or len(url) < 7:
-            raise PySparkValueError(
-                error_class="INVALID_CONNECT_URL",
-                message_parameters={
-                    "detail": "The URL must start with 'unix://'. Please update the URL to follow the correct format, e.g., 'unix://unix_domain_socket_path'.",
-                },
-            )
-
-        # Rewrite the URL to use http as the scheme so that we can leverage
-        # Python's built-in parser to parse URL parameters
-        fake_url = "http://" + url[6:]
-        self.url = urllib.parse.urlparse(fake_url)
-        self.params: Dict[str, str] = {}
-        self._extract_attributes()
-
-        # Now parse the real unix domain socket URL
-        self.url = urllib.parse.urlparse(url)
-
-        GRPC_DEFAULT_OPTIONS = [
-            ("grpc.max_send_message_length", _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE),
-            ("grpc.max_receive_message_length", _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE),
-            ("grpc.max_metadata_size", _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE),
-            (
-                "grpc.absolute_max_metadata_size",
-                2 * _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            ),
-        ]
-
-        if channelOptions is None:
-            self._channel_options = GRPC_DEFAULT_OPTIONS
-        else:
-            self._channel_options = GRPC_DEFAULT_OPTIONS + channelOptions
-        # For Spark 4.0 support, but also backwards compatible.
-        self._params = self.params
-
-    def _extract_attributes(self) -> None:
-        """Extract attributes from parameters.
-
-        This method was copied from
-        https://github.com/apache/spark/blob/branch-3.5/python/pyspark/sql/connect/client/core.py
-
-        This is required for Spark 4.0 support, since it is dropped in favor of moving
-        the extraction logic into the constructor.
-        """
-        if len(self.url.params) > 0:
-            parts = self.url.params.split(";")
-            for p in parts:
-                kv = p.split("=")
-                if len(kv) != 2:
-                    raise PySparkValueError(
-                        error_class="INVALID_CONNECT_URL",
-                        message_parameters={
-                            "detail": f"Parameter '{p}' should be provided as a "
-                            f"key-value pair separated by an equal sign (=). Please update "
-                            f"the parameter to follow the correct format, e.g., 'key=value'.",
-                        },
-                    )
-                self.params[kv[0]] = urllib.parse.unquote(kv[1])
-
-        netloc = self.url.netloc.split(":")
-        if len(netloc) == 1:
-            self.host = netloc[0]
-            if version.parse(pyspark.__version__) >= version.parse("4.0.0"):
-                from pyspark.sql.connect.client.core import DefaultChannelBuilder
-
-                self.port = DefaultChannelBuilder.default_port()
-            else:
-                self.port = ChannelBuilder.default_port()
-        elif len(netloc) == 2:
-            self.host = netloc[0]
-            self.port = int(netloc[1])
-        else:
-            raise PySparkValueError(
-                error_class="INVALID_CONNECT_URL",
-                message_parameters={
-                    "detail": f"Target destination '{self.url.netloc}' should match the
|
|
938
|
-
f"'<host>:<port>' pattern. Please update the destination to follow "
|
|
939
|
-
f"the correct format, e.g., 'hostname:port'.",
|
|
940
|
-
},
|
|
941
|
-
)
|
|
942
|
-
|
|
943
|
-
# We override this to enable compatibility with Spark 4.0
|
|
944
|
-
host = None
|
|
945
|
-
|
|
946
|
-
@property
|
|
947
|
-
def endpoint(self) -> str:
|
|
948
|
-
return f"{self.url.scheme}:{self.url.path}"
|
|
949
|
-
|
|
950
|
-
def toChannel(self) -> grpc.Channel:
|
|
951
|
-
return grpc.insecure_channel(self.endpoint, options=self._channel_options)
|
|
1093
|
+
# End the root span when server shuts down completely
|
|
1094
|
+
otel_end_root_span()
|
|
952
1095
|
|
|
953
1096
|
|
|
954
1097
|
def config_snowpark() -> None:
|
|
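The block removed above (the URL-selection helpers plus `UnixDomainSocketChannelBuilder`) ultimately reduces to opening an insecure gRPC channel against a `unix:` target. A minimal, hedged sketch of that idea follows; the socket path and the message-size option are illustrative assumptions, not values taken from this package.

```python
# Illustrative sketch only: dial a Spark Connect-style gRPC endpoint over a
# Unix domain socket. The socket path and option value below are assumptions.
import grpc

SOCKET_PATH = "/tmp/snowflake_sas_grpc.sock"  # hypothetical path

channel = grpc.insecure_channel(
    f"unix:{SOCKET_PATH}",
    options=[("grpc.max_receive_message_length", 128 * 1024 * 1024)],
)
```

The removed class layered Spark Connect's `ChannelBuilder` URL and parameter parsing on top of essentially this call (see its `toChannel` method).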
@@ -977,12 +1120,24 @@ def start_jvm():
         if tcm.TCM_MODE:
             # No-op if JVM is already started in TCM mode
             return
-        raise RuntimeError(
+        exception = RuntimeError(
             "JVM must not be running when starting the Spark Connect server"
         )
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception

+    import pathlib
+    import zipfile
+
+    import snowflake.snowpark_connect
+
+    # Import both JAR dependency packages
+    import snowpark_connect_deps_1
+    import snowpark_connect_deps_2
+
+    # First, add JARs from includes/jars directory
     pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes/jars"
+        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes" / "jars"
     )

     if "dataframe_processor.zip" in str(pyspark_jars):
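The `start_jvm` change above swaps a bare `raise RuntimeError(...)` for an attach-then-raise pattern: build the exception, tag it with a package error code, then raise it. A small sketch of the pattern, with simplified stand-ins for `attach_custom_error_code` and `ErrorCodes` rather than the package's own implementations:

```python
# Sketch of the attach-then-raise pattern; ErrorCodes and
# attach_custom_error_code are simplified stand-ins, not the package's code.
class ErrorCodes:
    INTERNAL_ERROR = "INTERNAL_ERROR"


def attach_custom_error_code(exc: Exception, code: str) -> Exception:
    # Record the code on the exception so callers can report it uniformly.
    exc.custom_error_code = code
    return exc


def ensure_jvm_not_running(jvm_already_started: bool) -> None:
    if jvm_already_started:
        exception = RuntimeError(
            "JVM must not be running when starting the Spark Connect server"
        )
        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
        raise exception
```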
@@ -991,18 +1146,31 @@ def start_jvm():
             snowflake.snowpark_connect.__file__
         ).parent.parent.parent
         temp_dir = tempfile.gettempdir()
-
         extract_folder = "snowflake/snowpark_connect/includes/jars/"  # Folder to extract (must end with '/')

         with zipfile.ZipFile(zip_path, "r") as zip_ref:
             for member in zip_ref.namelist():
                 if member.startswith(extract_folder):
                     zip_ref.extract(member, path=temp_dir)
-
         pyspark_jars = pathlib.Path(temp_dir) / extract_folder

-
-
+    included_jar_names = set()
+
+    if pyspark_jars.exists():
+        for jar_path in pyspark_jars.glob(
+            "**/*.jar"
+        ):  # Use **/*.jar to handle nested paths in TCM
+            jpype.addClassPath(str(jar_path))
+            included_jar_names.add(jar_path.name)
+
+    # Load jar files from both packages, skipping those already loaded from includes/jars
+    jar_path_list = (
+        snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars()
+    )
+    for jar_path in jar_path_list:
+        # Skip if this JAR was already loaded from includes/jars
+        if jar_path.name not in included_jar_names:
+            jpype.addClassPath(jar_path)

     # TODO: Should remove convertStrings, but it breaks the JDBC code.
     jvm_settings: list[str] = list(
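The classpath logic added above prefers the bundled JARs and only adds a dependency-package JAR when no bundled JAR with the same file name was already registered. A compact sketch of that de-duplication; `bundled_dir` and `deps_jars` are hypothetical stand-ins for the includes/jars directory and the `list_jars()` results:

```python
# De-duplicate JARs by file name: bundled JARs win, dependency JARs fill gaps.
# bundled_dir and deps_jars are hypothetical inputs used only for illustration.
import pathlib
from typing import Iterable

import jpype


def add_jars(bundled_dir: pathlib.Path, deps_jars: Iterable[pathlib.Path]) -> None:
    seen = set()
    if bundled_dir.exists():
        for jar in bundled_dir.glob("**/*.jar"):  # handles nested layouts too
            jpype.addClassPath(str(jar))
            seen.add(jar.name)
    for jar in deps_jars:
        if jar.name not in seen:  # skip names already on the classpath
            jpype.addClassPath(str(jar))
```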
@@ -1027,6 +1195,7 @@ def start_session(
     snowpark_session: Optional[snowpark.Session] = None,
     connection_parameters: Optional[Dict[str, str]] = None,
     max_grpc_message_size: int = _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    _add_signal_handler: bool = False,
 ) -> threading.Thread | None:
     """
     Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.
@@ -1048,147 +1217,80 @@ def start_session(
         connection_parameters: A dictionary of connection parameters to use to create the Snowpark session. If this is
             provided, the `snowpark_session` parameter must be None.
     """
-
-
-
-    _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size
+    # Increase recursion limit to 1100 (1000 by default)
+    # introduced due to Scala OSS Test: org.apache.spark.sql.ClientE2ETestSuite.spark deep recursion
+    sys.setrecursionlimit(1100)

-
+    # Apply PySpark Connect client patching for enhanced debugging (only if telemetry is enabled)
+    from snowflake.snowpark_connect.utils.patch_spark_line_number import (
+        patch_pyspark_connect,
+    )

-
+    if is_telemetry_enabled():
+        patch_pyspark_connect()

-
-
-
-    )
-    if connection_parameters is not None:
-        if snowpark_session is not None:
-            raise ValueError(
-                "Only specify one of snowpark_session and connection_parameters"
-            )
-        snowpark_session = snowpark.Session.builder.configs(
-            connection_parameters
-        ).create()
+    try:
+        # Set max grpc message size if provided
+        if max_grpc_message_size is not None:
+            set_grpc_max_message_size(max_grpc_message_size)

-
-
+        # Validate startup parameters
+        snowpark_session = validate_startup_parameters(
+            snowpark_session, connection_parameters
+        )
+
+        server_running = get_server_running()
+        if server_running.is_set():
             url = get_client_url()
             logger.warning(f"Snowpark Connect session is already running at {url}")
             return

-
-            raise RuntimeError(
-                "Can only set at most one of remote_url, tcp_port, and unix_domain_socket"
-            )
-
-        url_from_env = os.environ.get("SPARK_REMOTE", None)
-        if remote_url:
-            _set_remote_url(remote_url)
-        elif tcp_port:
-            _set_server_tcp_port(tcp_port)
-        elif unix_domain_socket:
-            _set_server_unix_domain_socket(unix_domain_socket)
-        elif url_from_env:
-            # Spark clients use environment variable SPARK_REMOTE to figure out Spark Connect URL. If none of the
-            # connection properties (remote_url, tcp_port, unix_domain_socket) are explicitly passed in to this
-            # function then we should try and mimic clients' behavior
-            # i.e. read the server URL from the SPARK_REMOTE environment variable.
-            _set_remote_url(url_from_env)
-        else:
-            # No connection properties can be found at all - either as arguments to this function or int the environment
-            # variable. We use random, unique Unix Domain Socket as a last fallback. Client can connect to this randomly
-            # generated UDS port using snowpark_connect.get_session().
-            # Mostly used in stored procs and Notebooks to avoid port conflicts.
-            if os.name == "nt":
-                # Windows does not support unix domain sockets, so use default TCP port instead.
-                _set_server_tcp_port(DEFAULT_PORT)
-            else:
-                # Generate unique, random UDS port. Mostly useful in stored proc environment to avoid port conflicts.
-                unix_domain_socket = _make_unix_domain_socket()
-                _set_server_unix_domain_socket(unix_domain_socket)
+        configure_server_url(remote_url, tcp_port, unix_domain_socket)

         start_jvm()
         _disable_protobuf_recursion_limit()
+        otel_initialize()
+
+        if _add_signal_handler:
+            setup_signal_handlers(stop_event)

         if is_daemon:
             arguments = (stop_event, snowpark_session)
-            # `daemon=True` ensures the server thread exits when script finishes.
-            server_thread = threading.Thread(target=_serve, args=arguments, daemon=True)
-            server_thread.start()
-            _server_running.wait()
-            if _server_error:
-                raise RuntimeError("Snowpark Connect session failed to start")
-            return server_thread
-        else:
-            # Launch in the foreground.
-            _serve(session=snowpark_session)
-    except Exception as e:
-        _reset_server_run_state()
-        logger.error(e, exc_info=True)
-        raise e

+            target_func = otel_create_context_wrapper(_serve)

-
-
-
-
-
-
-
-
-
-
-
-        RuntimeError: If Spark Connect server is not started.
-    """
-    try:
-        if not url:
-            url = get_client_url()
+            server_thread = threading.Thread(
+                target=target_func, args=arguments, daemon=True
+            )
+            server_thread.start()
+            server_running.wait()
+            if get_server_error():
+                exception = RuntimeError("Snowpark Connect session failed to start")
+                attach_custom_error_code(
+                    exception, ErrorCodes.STARTUP_CONNECTION_FAILED
+                )
+                raise exception

-
-            b = SparkSession.builder.channelBuilder(UnixDomainSocketChannelBuilder())
+            return server_thread
         else:
-
-
-            if conf is not None:
-                for k, v in conf.getAll():
-                    b.config(k, v)
-
-            return b.getOrCreate()
+            # Launch in the foreground with stop_event
+            _serve(stop_event=stop_event, session=snowpark_session)
     except Exception as e:
         _reset_server_run_state()
         logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e


 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-
-
-        from jdk4py import JAVA_HOME
-
-        os.environ["JAVA_HOME"] = str(JAVA_HOME)
-    except ModuleNotFoundError:
-        # For notebooks on Warehouse
-        os.environ["JAVA_HOME"] = os.environ["CONDA_PREFIX"]
-        os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
-            os.environ["CONDA_PREFIX"], "lib", "server"
-        )
-    logger.info("JAVA_HOME=%s", os.environ["JAVA_HOME"])
+    _setup_spark_environment()
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session

-
-    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
-
-    snowpark_session = snowpark.context.get_active_session()
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)


-def enable_debug_logging():
-    logger.setLevel(logging.DEBUG)
-    for handler in logger.handlers:
-        handler.setLevel(logging.DEBUG)
-
-
 def _get_files_metadata(data_source: relations_proto.Read.DataSource) -> List[str]:
     # TODO: Handle paths on the cloud
     paths = data_source.paths
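The daemon-mode path above is an event-based readiness handshake: `_serve` runs on a daemon thread, the caller blocks on a shared `threading.Event` until the server is up (or has recorded a failure), and only then returns the thread. A stripped-down sketch of that handshake; every name below is an illustrative stand-in, not the package's accessors:

```python
# Readiness handshake sketch: the worker signals an Event whether startup
# succeeded or failed, so the caller never blocks forever.
import threading

server_running = threading.Event()
server_error = False


def _serve_stub() -> None:
    global server_error
    try:
        pass  # the real code would build and start the gRPC server here
    except Exception:
        server_error = True
    finally:
        server_running.set()  # unblock the caller in both outcomes


server_thread = threading.Thread(target=_serve_stub, daemon=True)
server_thread.start()
server_running.wait()
if server_error:
    raise RuntimeError("Snowpark Connect session failed to start")
```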
@@ -1206,15 +1308,3 @@ def _get_files_metadata(data_source: relations_proto.Read.DataSource) -> List[str]:
         ]
     )
     return files
-
-
-def _disable_protobuf_recursion_limit():
-    # https://github.com/protocolbuffers/protobuf/blob/960e79087b332583c80537c949621108a85aa442/src/google/protobuf/io/coded_stream.h#L616
-    # Disable protobuf recursion limit (default 100) because Spark workloads often produce deeply nested execution plans. For example:
-    # - Queries with many unions
-    # - Complex expressions with multiple levels of nesting
-    # Without this, legitimate Spark queries would fail with `(DecodeError) Error parsing message with type 'spark.connect.Relation'` error.
-    # see test_sql_resulting_in_nested_protobuf
-    from google.protobuf.pyext import cpp_message
-
-    cpp_message._message.SetAllowOversizeProtos(True)