snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
@@ -21,17 +21,13 @@
21
21
  # limitations under the License.
22
22
  #
23
23
 
24
- import atexit
25
- import logging
24
+
26
25
  import os
27
- import pathlib
28
- import socket
26
+ import sys
29
27
  import tempfile
30
28
  import threading
31
- import urllib.parse
32
- import zipfile
33
29
  from concurrent import futures
34
- from typing import Any, Callable, Dict, List, Optional, Tuple
30
+ from typing import Callable, Dict, List, Optional
35
31
 
36
32
  import grpc
37
33
  import jpype
@@ -41,14 +37,10 @@ import pyspark.sql.connect.proto.base_pb2_grpc as proto_base_grpc
41
37
  import pyspark.sql.connect.proto.common_pb2 as common_proto
42
38
  import pyspark.sql.connect.proto.relations_pb2 as relations_proto
43
39
  import pyspark.sql.connect.proto.types_pb2 as types_proto
44
- from packaging import version
45
40
  from pyspark import StorageLevel
46
41
  from pyspark.conf import SparkConf
47
- from pyspark.errors import PySparkValueError
48
- from pyspark.sql.connect.client.core import ChannelBuilder
49
42
  from pyspark.sql.connect.session import SparkSession
50
43
 
51
- import snowflake.snowpark_connect
52
44
  import snowflake.snowpark_connect.proto.control_pb2_grpc as control_grpc
53
45
  import snowflake.snowpark_connect.tcm as tcm
54
46
  from snowflake import snowpark
@@ -56,7 +48,11 @@ from snowflake.snowpark_connect.analyze_plan.map_tree_string import map_tree_str
56
48
  from snowflake.snowpark_connect.config import route_config_proto
57
49
  from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
58
50
  from snowflake.snowpark_connect.control_server import ControlServicer
59
- from snowflake.snowpark_connect.error.error_utils import build_grpc_error_response
51
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
52
+ from snowflake.snowpark_connect.error.error_utils import (
53
+ attach_custom_error_code,
54
+ build_grpc_error_response,
55
+ )
60
56
  from snowflake.snowpark_connect.execute_plan.map_execution_command import (
61
57
  map_execution_command,
62
58
  )
@@ -66,7 +62,26 @@ from snowflake.snowpark_connect.execute_plan.map_execution_root import (
66
62
  from snowflake.snowpark_connect.relation.map_local_relation import map_local_relation
67
63
  from snowflake.snowpark_connect.relation.map_relation import map_relation
68
64
  from snowflake.snowpark_connect.relation.utils import get_semantic_string
69
- from snowflake.snowpark_connect.resources_initializer import initialize_resources_async
65
+ from snowflake.snowpark_connect.resources_initializer import initialize_resources
66
+ from snowflake.snowpark_connect.server_common import ( # noqa: F401 - re-exported for public API
67
+ _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
68
+ _client_telemetry_context,
69
+ _disable_protobuf_recursion_limit,
70
+ _get_default_grpc_options,
71
+ _reset_server_run_state,
72
+ _setup_spark_environment,
73
+ _stop_server,
74
+ configure_server_url,
75
+ get_client_url,
76
+ get_server_error,
77
+ get_server_running,
78
+ get_server_url,
79
+ get_session,
80
+ set_grpc_max_message_size,
81
+ set_server_error,
82
+ setup_signal_handlers,
83
+ validate_startup_parameters,
84
+ )
70
85
  from snowflake.snowpark_connect.type_mapping import (
71
86
  map_type_string_to_proto,
72
87
  snowpark_to_proto_type,
@@ -82,12 +97,13 @@ from snowflake.snowpark_connect.utils.cache import (
82
97
  df_cache_map_put_if_absent,
83
98
  )
84
99
  from snowflake.snowpark_connect.utils.context import (
100
+ clean_request_external_tables,
85
101
  clear_context_data,
86
- get_session_id,
87
- set_session_id,
102
+ get_request_external_tables,
103
+ get_spark_session_id,
104
+ set_spark_session_id,
88
105
  set_spark_version,
89
106
  )
90
- from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
91
107
  from snowflake.snowpark_connect.utils.external_udxf_cache import (
92
108
  clear_external_udxf_cache,
93
109
  )
@@ -96,7 +112,25 @@ from snowflake.snowpark_connect.utils.interrupt import (
96
112
  interrupt_queries_with_tag,
97
113
  interrupt_query,
98
114
  )
99
- from snowflake.snowpark_connect.utils.profiling import profile_method
115
+ from snowflake.snowpark_connect.utils.java_stored_procedure import (
116
+ set_java_udf_creator_initialized_state,
117
+ )
118
+ from snowflake.snowpark_connect.utils.open_telemetry import (
119
+ is_telemetry_enabled,
120
+ otel_attach_context,
121
+ otel_create_context_wrapper,
122
+ otel_create_status,
123
+ otel_detach_context,
124
+ otel_end_root_span,
125
+ otel_flush_telemetry,
126
+ otel_get_current_span,
127
+ otel_get_root_span_context,
128
+ otel_get_status_code,
129
+ otel_get_tracer,
130
+ otel_initialize,
131
+ otel_start_span_as_current,
132
+ )
133
+ from snowflake.snowpark_connect.utils.profiling import PROFILING_ENABLED, profile_method
100
134
  from snowflake.snowpark_connect.utils.session import (
101
135
  configure_snowpark_session,
102
136
  get_or_create_snowpark_session,
@@ -112,29 +146,111 @@ from snowflake.snowpark_connect.utils.telemetry import (
112
146
  )
113
147
  from snowflake.snowpark_connect.utils.xxhash64 import xxhash64_string
114
148
 
115
- DEFAULT_PORT = 15002
116
149
 
117
- # https://github.com/apache/spark/blob/v3.5.3/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala#L21
118
- _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = 128 * 1024 * 1024
119
- # TODO: Verify if we we want to configure it via env variables.
120
- _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE = 64 * 1024 # 64kb
150
+ def _store_client_stack_trace(client_stack_info):
151
+ """Store client stack trace in thread-local storage"""
152
+
153
+ _client_telemetry_context.stack_trace = client_stack_info
154
+
155
+
156
+ def _clear_client_stack_trace():
157
+ """Clear client stack trace"""
158
+
159
+ _client_telemetry_context.stack_trace = None
160
+
161
+
162
+ def _get_client_stack_trace():
163
+ """Get current client stack trace"""
164
+
165
+ return getattr(_client_telemetry_context, "stack_trace", None)
166
+
121
167
 
168
+ def _add_client_stack_trace_to_span(span, client_stack):
169
+ """
170
+ Add formatted client stack trace to a specific span.
122
171
 
123
- def _sanitize_file_paths(text: str) -> str:
172
+ Args:
173
+ span: The OpenTelemetry span to add the stack trace attribute to
174
+ client_stack: The client stack trace data (list of frame dicts)
124
175
  """
125
- Sanitize file paths in error messages by replacing them with placeholders.
126
- Only matches actual file paths, not module names or class names.
176
+ if not client_stack or not span or not span.is_recording():
177
+ return
178
+
179
+ stack_frames = []
180
+ for frame in client_stack:
181
+ if frame.get("file_name") and frame.get("line_number"):
182
+ method = frame.get("method_name", "unknown")
183
+ location = f"{frame.get('file_name')}:{frame.get('line_number')}"
184
+ stack_frames.append(f"{method} at {location}")
185
+
186
+ if stack_frames:
187
+ span.set_attribute("client.stack_trace", " <- ".join(stack_frames))
188
+
189
+
190
+ def _process_and_store_client_stack_trace(request, add_to_span: bool = False):
127
191
  """
128
- import re
192
+ Extract, store, and optionally add client stack trace to the current span.
129
193
 
130
- # Pattern to match file paths in traceback "File" lines only
131
- # This targets the specific format: File "/path/to/file.py", line XX
132
- file_line_pattern = r'(File\s+["\'])([^"\']+)(["\'],\s+line\s+\d+)'
194
+ Args:
195
+ request: The gRPC request containing user context with stack trace
196
+ add_to_span: If True, format and add stack trace as span attribute to current span
133
197
 
134
- def replace_file_path(match):
135
- return f"{match.group(1)}<redacted_file_path>{match.group(3)}"
198
+ Returns:
199
+ The extracted client_stack (or None) for use in ExecutePlan
200
+ """
201
+ # Extract and store client stack trace information for telemetry
202
+ client_stack = _extract_and_log_user_stack_trace(request)
203
+ if client_stack:
204
+ _store_client_stack_trace(client_stack)
136
205
 
137
- return re.sub(file_line_pattern, replace_file_path, text)
206
+ # Set span attribute with formatted stack trace (if requested and available)
207
+ if add_to_span and client_stack:
208
+ root_span_otel_context = otel_get_root_span_context()
209
+ if root_span_otel_context is not None and is_telemetry_enabled():
210
+ current_span = otel_get_current_span()
211
+ if current_span and current_span.is_recording():
212
+ _add_client_stack_trace_to_span(current_span, client_stack)
213
+
214
+ return client_stack
215
+
216
+
217
+ def _extract_and_log_user_stack_trace(request):
218
+ """
219
+ Extract and log user stack trace information from request extensions.
220
+
221
+ Args:
222
+ request: The gRPC request containing user_context.extensions
223
+
224
+ Returns:
225
+ List of stack trace frames or None if no traces found
226
+ """
227
+ try:
228
+ from snowflake.snowpark_connect.utils.patch_spark_line_number import (
229
+ extract_stack_trace_from_extensions,
230
+ )
231
+
232
+ if hasattr(request, "user_context") and hasattr(
233
+ request.user_context, "extensions"
234
+ ):
235
+ stack_traces = extract_stack_trace_from_extensions(
236
+ request.user_context.extensions
237
+ )
238
+
239
+ if stack_traces:
240
+ logger.debug("User code stack trace:")
241
+ for i, frame in enumerate(stack_traces):
242
+ logger.debug(
243
+ f" Frame {i}: {frame.get('method_name', 'unknown')} "
244
+ f"at {frame.get('file_name', 'unknown')}:{frame.get('line_number', 'unknown')}"
245
+ )
246
+ return stack_traces # Return the stack traces for telemetry use
247
+ else:
248
+ logger.debug("No user stack trace information found in request")
249
+ return None
250
+ except Exception as e:
251
+ # Don't let stack trace extraction errors affect the main request
252
+ logger.debug(f"Failed to extract user stack trace: {e}")
253
+ return None
138
254
 
139
255
 
140
256
  def _handle_exception(context, e: Exception):
@@ -147,16 +263,15 @@ def _handle_exception(context, e: Exception):
147
263
  if show_traceback:
148
264
  # Show detailed traceback (includes error info naturally)
149
265
  error_traceback = traceback.format_exc()
150
- sanitized_traceback = _sanitize_file_paths(error_traceback)
151
- logger.error(sanitized_traceback)
266
+ logger.error(error_traceback)
152
267
  else:
153
268
  # Show only basic error information, no traceback
154
269
  logger.error("Error: %s - %s", type(e).__name__, str(e))
155
270
 
156
271
  telemetry.report_request_failure(e)
157
-
158
272
  if tcm.TCM_MODE:
159
- # TODO: SNOW-2009834 gracefully return error back in TCM
273
+ # spark decoder will catch the error and return it to GS gracefully
274
+ attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
160
275
  raise e
161
276
 
162
277
  from grpc_status import rpc_status
@@ -165,14 +280,71 @@ def _handle_exception(context, e: Exception):
165
280
  context.abort_with_status(rpc_status.to_status(rich_status))
166
281
 
167
282
 
283
+ # Decorator for creating method spans as children of root span
284
+ def _with_method_span(method_name):
285
+ """
286
+ Decorator to create a new span as child of root span for gRPC methods and provide it as parent to Snowpark operations.
287
+ """
288
+
289
+ def decorator(func):
290
+ def wrapper(*args, **kwargs):
291
+ # Get the root span context first
292
+ root_span_otel_context = otel_get_root_span_context()
293
+
294
+ # Only proceed if BOTH conditions are true
295
+ if root_span_otel_context is not None and is_telemetry_enabled():
296
+ # Attach the root context first, then create child span
297
+ context_token = otel_attach_context(root_span_otel_context)
298
+
299
+ try:
300
+ tracer = otel_get_tracer(__name__)
301
+ span_name = f"snowpark_connect.{method_name}"
302
+
303
+ # Create span as child of the root span context
304
+ span_context_mgr = otel_start_span_as_current(tracer, span_name)
305
+ if span_context_mgr:
306
+ with span_context_mgr as span:
307
+ try:
308
+ # Execute the method with the new span as current context
309
+ return func(*args, **kwargs)
310
+
311
+ except Exception as e:
312
+ # Record the exception in the span
313
+ span.record_exception(e)
314
+ StatusCode = otel_get_status_code()
315
+ if StatusCode:
316
+ status = otel_create_status(
317
+ StatusCode.ERROR, str(e)
318
+ )
319
+ if status:
320
+ span.set_status(status)
321
+ raise
322
+ else:
323
+ # No span created, just execute the function
324
+ return func(*args, **kwargs)
325
+
326
+ finally:
327
+ # Always detach the root context
328
+ if context_token is not None:
329
+ otel_detach_context(context_token)
330
+ else:
331
+ # No root context available or OTel not available, execute without span
332
+ return func(*args, **kwargs)
333
+
334
+ return wrapper
335
+
336
+ return decorator
337
+
338
+
339
+ # Snowflake Connect gRPC Service Implementation
168
340
  class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
169
341
  def __init__(
170
342
  self,
171
343
  log_request_fn: Optional[Callable[[bytearray], None]] = None,
172
344
  ) -> None:
173
345
  self.log_request_fn = log_request_fn
174
- # Trigger async initialization here, so that we reduce overhead for rpc calls.
175
- initialize_resources_async()
346
+ # Trigger synchronous initialization here, so that we reduce overhead for rpc calls.
347
+ initialize_resources()
176
348
 
177
349
  @profile_method
178
350
  def ExecutePlan(self, request: proto_base.ExecutePlanRequest, context):
@@ -181,20 +353,45 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
181
353
  It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
182
354
  """
183
355
  logger.info("ExecutePlan")
356
+
357
+ client_stack = _process_and_store_client_stack_trace(request, add_to_span=False)
358
+
184
359
  if self.log_request_fn is not None:
185
360
  self.log_request_fn(request.SerializeToString())
186
361
 
187
362
  # TODO: remove session id context when we host this in Snowflake server
188
363
  # set the thread-local context of session id
189
364
  clear_context_data()
190
- set_session_id(request.session_id)
365
+ set_spark_session_id(request.session_id)
191
366
  set_spark_version(request.client_type)
192
367
  telemetry.initialize_request_summary(request)
193
368
 
194
369
  set_query_tags(request.tags)
195
370
 
196
- result_iter = iter(())
371
+ # Additional context attachment for Snowpark DataFrame operations
372
+ snowpark_context_token = None
373
+ span = None
374
+ span_context_manager = None
197
375
  try:
376
+ root_span_otel_context = otel_get_root_span_context()
377
+
378
+ if root_span_otel_context is not None and is_telemetry_enabled():
379
+ snowpark_context_token = otel_attach_context(root_span_otel_context)
380
+
381
+ # Create span manually for generator function and make it current
382
+ tracer = otel_get_tracer(__name__)
383
+ span_context_manager = otel_start_span_as_current(
384
+ tracer, "snowpark_connect.ExecutePlan"
385
+ )
386
+ span = None
387
+ if span_context_manager:
388
+ span = (
389
+ span_context_manager.__enter__()
390
+ ) # Start the span context AND make it current
391
+ # Add stack trace to this manually created span
392
+ _add_client_stack_trace_to_span(span, client_stack)
393
+
394
+ result_iter = iter(())
198
395
  match request.plan.WhichOneof("op_type"):
199
396
  case "root":
200
397
  logger.info("ROOT")
@@ -212,32 +409,60 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
212
409
  result_complete=proto_base.ExecutePlanResponse.ResultComplete(),
213
410
  )
214
411
  except Exception as e:
412
+ if span:
413
+ span.record_exception(e)
414
+ StatusCode = otel_get_status_code()
415
+ if StatusCode:
416
+ status = otel_create_status(StatusCode.ERROR, str(e))
417
+ if status:
418
+ span.set_status(status)
215
419
  _handle_exception(context, e)
216
420
  finally:
421
+ if span_context_manager:
422
+ span_context_manager.__exit__(None, None, None) # End the span
423
+ if snowpark_context_token is not None:
424
+ otel_detach_context(snowpark_context_token)
425
+ # Clear client stack trace when request is done
426
+ _clear_client_stack_trace()
427
+ otel_flush_telemetry()
428
+ self._cleanup_external_tables()
217
429
  telemetry.send_request_summary_telemetry()
218
430
 
219
431
  @profile_method
432
+ @_with_method_span("AnalyzePlan")
220
433
  def AnalyzePlan(self, request: proto_base.AnalyzePlanRequest, context):
221
434
  """Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query."""
222
435
  logger.info(f"AnalyzePlan: {request.WhichOneof('analyze')}")
436
+
437
+ _process_and_store_client_stack_trace(request, add_to_span=True)
438
+
223
439
  if self.log_request_fn is not None:
224
440
  self.log_request_fn(request.SerializeToString())
441
+
225
442
  try:
226
443
  # TODO: remove session id context when we host this in Snowflake server
227
444
  # set the thread-local context of session id
228
445
  clear_context_data()
229
- set_session_id(request.session_id)
446
+ set_spark_session_id(request.session_id)
230
447
  set_spark_version(request.client_type)
231
448
  telemetry.initialize_request_summary(request)
232
449
  match request.WhichOneof("analyze"):
233
450
  case "schema":
234
451
  result = map_relation(request.schema.plan.root)
235
- snowpark_df = result.dataframe
236
- snowpark_schema: snowpark.types.StructType = snowpark_df.schema
452
+
453
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
454
+ without_internal_columns,
455
+ )
456
+
457
+ filtered_result = without_internal_columns(result)
458
+ filtered_df = filtered_result.dataframe
459
+
237
460
  schema = proto_base.AnalyzePlanResponse.Schema(
238
461
  schema=types_proto.DataType(
239
462
  **snowpark_to_proto_type(
240
- snowpark_schema, result.column_map, snowpark_df
463
+ filtered_df.schema,
464
+ filtered_result.column_map,
465
+ filtered_df,
241
466
  )
242
467
  )
243
468
  )
@@ -274,10 +499,15 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
274
499
  plan_id = request.persist.relation.common.plan_id
275
500
  # cache the plan if it is not already in the map
276
501
 
502
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
503
+ without_internal_columns,
504
+ )
505
+
277
506
  df_cache_map_put_if_absent(
278
507
  (request.session_id, plan_id),
279
- lambda: map_relation(request.persist.relation),
280
- materialize=True,
508
+ lambda: without_internal_columns(
509
+ map_relation(request.persist.relation)
510
+ ),
281
511
  )
282
512
 
283
513
  storage_level = request.persist.storage_level
@@ -366,15 +596,24 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
366
596
  ),
367
597
  )
368
598
  case _:
369
- raise SnowparkConnectNotImplementedError(
599
+ exception = SnowparkConnectNotImplementedError(
370
600
  f"ANALYZE PLAN NOT IMPLEMENTED:\n{request}"
371
601
  )
602
+ attach_custom_error_code(
603
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
604
+ )
605
+ raise exception
372
606
  except Exception as e:
373
607
  _handle_exception(context, e)
374
608
  finally:
609
+ # Clear client stack trace when request is done
610
+ _clear_client_stack_trace()
611
+ otel_flush_telemetry()
612
+ self._cleanup_external_tables()
375
613
  telemetry.send_request_summary_telemetry()
376
614
 
377
615
  @staticmethod
616
+ @_with_method_span("Config")
378
617
  def Config(
379
618
  request: proto_base.ConfigRequest,
380
619
  context,
@@ -389,12 +628,18 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
389
628
  ):
390
629
  """Update or fetch the configurations and returns a [[ConfigResponse]] containing the result."""
391
630
  logger.info("Config")
631
+
632
+ _process_and_store_client_stack_trace(request, add_to_span=True)
633
+
392
634
  try:
393
635
  telemetry.initialize_request_summary(request)
394
636
  return route_config_proto(request, get_or_create_snowpark_session())
395
637
  except Exception as e:
396
638
  _handle_exception(context, e)
397
639
  finally:
640
+ # Clear client stack trace when request is done
641
+ _clear_client_stack_trace()
642
+ otel_flush_telemetry()
398
643
  telemetry.send_request_summary_telemetry()
399
644
 
400
645
  def AddArtifacts(self, request_iterator, context):
@@ -402,11 +647,9 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
402
647
  the added artifacts.
403
648
  """
404
649
  logger.info("AddArtifacts")
650
+
405
651
  session: snowpark.Session = get_or_create_snowpark_session()
406
- filenames: dict[str, str] = {}
407
652
  response: dict[str, proto_base.AddArtifactsResponse.ArtifactSummary] = {}
408
- # Store accumulated data for local relation cache
409
- cache_data: dict[str, bytearray] = {}
410
653
 
411
654
  def _try_handle_local_relation(artifact_name: str, data: bytes):
412
655
  """
@@ -422,12 +665,14 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
422
665
  ) # heuristic to identify local relations
423
666
 
424
667
  def _handle_regular_artifact():
425
- filenames[artifact_name] = write_artifact(
668
+ artifact = write_artifact(
426
669
  session,
427
670
  artifact_name,
428
671
  data,
429
672
  overwrite=True,
430
673
  )
674
+ with session._filenames_lock:
675
+ session._filenames[get_spark_session_id()][artifact_name] = artifact
431
676
 
432
677
  if is_likely_local_relation:
433
678
  try:
@@ -435,9 +680,8 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
435
680
  l_relation.ParseFromString(data)
436
681
  relation = relations_proto.Relation(local_relation=l_relation)
437
682
  df_cache_map_put_if_absent(
438
- (get_session_id(), artifact_name.replace("cache/", "")),
683
+ (get_spark_session_id(), artifact_name.replace("cache/", "")),
439
684
  lambda: map_local_relation(relation), # noqa: B023
440
- materialize=True,
441
685
  )
442
686
  except Exception as e:
443
687
  logger.warning("Failed to put df into cache: %s", str(e))
@@ -458,29 +702,46 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
458
702
  # Batch artifacts are sent as a single "batch" message containing a list of
459
703
  # artifacts. We do not need to keep track of the name since it is included in
460
704
  # each artifact.
461
- current_name: str = ""
705
+
462
706
  for request in request_iterator:
463
707
  clear_context_data()
464
- set_session_id(request.session_id)
708
+ set_spark_session_id(request.session_id)
465
709
  set_spark_version(request.client_type)
710
+ with session._filenames_lock:
711
+ if request.session_id not in session._filenames:
712
+ session._filenames[request.session_id] = {}
713
+
466
714
  match request.WhichOneof("payload"):
467
715
  case "begin_chunk":
468
716
  current_name = request.begin_chunk.name
469
- assert (
470
- current_name not in filenames
471
- ), "Duplicate artifact name found."
717
+ current_chunk = {
718
+ "name": current_name,
719
+ "num_chunks": request.begin_chunk.num_chunks,
720
+ "current_chunk_index": 1,
721
+ }
722
+ with session._filenames_lock:
723
+ assert (
724
+ current_name not in session._filenames[request.session_id]
725
+ ), "Duplicate artifact name found."
472
726
 
473
727
  if current_name.startswith("cache/"):
474
- cache_data[current_name] = bytearray(
728
+ current_chunk["cache"] = bytearray(
475
729
  request.begin_chunk.initial_chunk.data
476
730
  )
477
731
  else:
478
- filenames[current_name] = write_artifact(
732
+ artifact = write_artifact(
479
733
  session,
480
734
  current_name,
481
735
  request.begin_chunk.initial_chunk.data,
482
736
  overwrite=True,
483
737
  )
738
+ with session._filenames_lock:
739
+ session._filenames[request.session_id][
740
+ current_name
741
+ ] = artifact
742
+ # cache current chunk
743
+ with session._current_chunk_lock:
744
+ session._current_chunk[request.session_id] = current_chunk
484
745
  response[
485
746
  current_name
486
747
  ] = proto_base.AddArtifactsResponse.ArtifactSummary(
@@ -491,18 +752,53 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
491
752
  ),
492
753
  )
493
754
  case "chunk":
755
+ # retrieve current chunk
756
+ with session._current_chunk_lock:
757
+ if request.session_id not in session._current_chunk:
758
+ exception = ValueError(
759
+ f"Received 'chunk' for session_id '{request.session_id}' without a prior 'begin_chunk'."
760
+ )
761
+ attach_custom_error_code(
762
+ exception, ErrorCodes.INTERNAL_ERROR
763
+ )
764
+ raise exception
765
+ current_chunk = session._current_chunk[request.session_id]
766
+
767
+ current_name = current_chunk["name"]
768
+ current_chunk["current_chunk_index"] += 1
494
769
  if current_name.startswith("cache/"):
495
- cache_data[current_name].extend(request.chunk.data)
770
+ current_chunk["cache"].extend(request.chunk.data)
496
771
  else:
497
- assert filenames[current_name] == write_artifact(
772
+ artifact = write_artifact(
498
773
  session, current_name, request.chunk.data
499
- ), "Artifact staging error."
774
+ )
775
+ with session._filenames_lock:
776
+ assert (
777
+ session._filenames[request.session_id][current_name]
778
+ == artifact
779
+ ), "Artifact staging error."
780
+
781
+ if (
782
+ current_chunk["current_chunk_index"]
783
+ == current_chunk["num_chunks"]
784
+ ):
785
+ # all chunks are ready
786
+ if current_name.startswith("cache/"):
787
+ _try_handle_local_relation(
788
+ current_name, bytes(current_chunk["cache"])
789
+ )
790
+ with session._current_chunk_lock:
791
+ # remove current chunk from session
792
+ del session._current_chunk[request.session_id]
500
793
 
501
794
  response[
502
795
  current_name
503
796
  ] = proto_base.AddArtifactsResponse.ArtifactSummary(
504
797
  name=current_name,
505
- is_crc_successful=response[current_name].is_crc_successful
798
+ is_crc_successful=(
799
+ current_name not in response
800
+ or response[current_name].is_crc_successful
801
+ )
506
802
  and check_checksum(request.chunk.data, request.chunk.crc),
507
803
  )
508
804
  case "batch":
@@ -519,62 +815,89 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
519
815
  ),
520
816
  )
521
817
  case _:
522
- raise ValueError(
818
+ exception = ValueError(
523
819
  f"Unexpected payload type in AddArtifacts: {request.WhichOneof('payload')}"
524
820
  )
525
-
526
- for name, data in cache_data.items():
527
- _try_handle_local_relation(name, bytes(data))
821
+ attach_custom_error_code(
822
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
823
+ )
824
+ raise exception
825
+
826
+ # if current chunk is still not finished, just return here
827
+ # This should only happen in TCM since we have to send request via rest one by one so current chunk cannot be
828
+ # finished in one iteration
829
+ with session._current_chunk_lock:
830
+ if request.session_id in session._current_chunk:
831
+ return proto_base.AddArtifactsResponse(
832
+ artifacts=list(response.values())
833
+ )
528
834
 
529
835
  class_files: dict[str, str] = {}
530
- for (name, filepath) in filenames.items():
531
- if name.endswith(".class"):
532
- # name is <dir>/<package>/<class_name>
533
- # we don't need the dir name, but require the package, so only remove dir
534
- if os.name != "nt":
535
- class_files[name.split("/", 1)[-1]] = filepath
536
- else:
537
- class_files[name.split("\\", 1)[-1]] = filepath
538
- continue
539
- session.file.put(
540
- filepath,
541
- session.get_session_stage(),
542
- auto_compress=False,
543
- overwrite=True,
544
- source_compression="GZIP" if name.endswith(".gz") else "NONE",
545
- )
546
-
547
- if name.startswith("cache"):
548
- continue
549
-
550
- # Remove temporary stored files which are put on the stage
551
- os.remove(filepath)
552
-
553
- # Add only files marked to be used in user generated Python UDFs.
554
- cached_name = f"{session.get_session_stage()}/{filepath.split('/')[-1]}"
555
- if not name.startswith("pyfiles") and cached_name in session._python_files:
556
- session._python_files.remove(cached_name)
557
- elif name.startswith("pyfiles"):
558
- session._python_files.add(cached_name)
559
-
560
- if not name.startswith("pyfiles"):
561
- session._import_files.add(cached_name)
562
-
563
- if class_files:
564
- write_class_files_to_stage(session, class_files)
836
+ with session._filenames_lock:
837
+ for (name, filepath) in session._filenames[get_spark_session_id()].items():
838
+ if name.endswith(".class"):
839
+ # name is <dir>/<package>/<class_name>
840
+ # we don't need the dir name, but require the package, so only remove dir
841
+ if os.name != "nt":
842
+ class_files[name.split("/", 1)[-1]] = filepath
843
+ else:
844
+ class_files[name.split("\\", 1)[-1]] = filepath
845
+ continue
846
+ session.file.put(
847
+ filepath,
848
+ session.get_session_stage(),
849
+ auto_compress=False,
850
+ overwrite=True,
851
+ source_compression="GZIP" if name.endswith(".gz") else "NONE",
852
+ )
565
853
 
566
- if any(not name.startswith("cache") for name in filenames.keys()):
567
- clear_external_udxf_cache(session)
854
+ if name.startswith("cache"):
855
+ continue
856
+
857
+ # Add only files marked to be used in user generated Python UDFs.
858
+ cached_name = f"{session.get_session_stage()}/{filepath.split('/')[-1]}"
859
+ if (
860
+ not name.startswith("pyfiles")
861
+ and cached_name in session._python_files
862
+ ):
863
+ session._python_files.remove(cached_name)
864
+ elif name.startswith("pyfiles"):
865
+ session._python_files.add(cached_name)
866
+
867
+ if name.startswith("jars/"):
868
+ session._artifact_jars.add(cached_name)
869
+ # Recreate the Java procedure to reload jars
870
+ set_java_udf_creator_initialized_state(False)
871
+ elif not name.startswith("pyfiles"):
872
+ session._import_files.add(cached_name)
873
+
874
+ # Remove temporary stored files which are put on the stage
875
+ os.remove(filepath)
876
+
877
+ if class_files:
878
+ jar_name = write_class_files_to_stage(session, class_files)
879
+ session._artifact_jars.add(jar_name)
880
+
881
+ if any(
882
+ not name.startswith("cache")
883
+ for name in session._filenames[get_spark_session_id()].keys()
884
+ ):
885
+ clear_external_udxf_cache(session)
886
+
887
+ # clear filenames for this session
888
+ session._filenames[get_spark_session_id()] = {}
568
889
 
569
890
  return proto_base.AddArtifactsResponse(artifacts=list(response.values()))
570
891
 
571
892
  def ArtifactStatus(self, request, context):
572
893
  """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
573
894
  logger.info("ArtifactStatus")
895
+
574
896
  clear_context_data()
575
- set_session_id(request.session_id)
897
+ set_spark_session_id(request.session_id)
576
898
  set_spark_version(request.client_type)
577
899
  session: snowpark.Session = get_or_create_snowpark_session()
900
+
578
901
  if os.name != "nt":
579
902
  tmp_path = f"/tmp/sas-{session.session_id}/"
580
903
  else:
@@ -583,7 +906,7 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
583
906
  def _is_local_relation_cached(name: str) -> bool:
584
907
  if name.startswith("cache/"):
585
908
  hash = name.replace("cache/", "")
586
- cached_df = df_cache_map_get((get_session_id(), hash))
909
+ cached_df = df_cache_map_get((get_spark_session_id(), hash))
587
910
  return cached_df is not None
588
911
  return False
589
912
 
@@ -618,6 +941,7 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
618
941
  # instead of using operation ids, we're relying on Snowflake query ids here, meaning that:
619
942
  # - The list of returned interrupted_ids contains query ids of interrupted jobs, instead of their operation ids
620
943
  # - INTERRUPT_TYPE_OPERATION_ID interrupt type expects a Snowflake query id instead of an operation id
944
+
621
945
  try:
622
946
  match request.interrupt_type:
623
947
  case proto_base.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL:
@@ -627,9 +951,13 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
627
951
  case proto_base.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID:
628
952
  interrupted_ids = interrupt_query(request.operation_id)
629
953
  case _:
630
- raise SnowparkConnectNotImplementedError(
954
+ exception = SnowparkConnectNotImplementedError(
631
955
  f"INTERRUPT NOT IMPLEMENTED:\n{request}"
632
956
  )
957
+ attach_custom_error_code(
958
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
959
+ )
960
+ raise exception
633
961
 
634
962
  return proto_base.InterruptResponse(
635
963
  session_id=request.session_id,
@@ -647,9 +975,12 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
647
975
  continue. If there is a ResultComplete, the client should use ReleaseExecute with
648
976
  """
649
977
  logger.info("ReattachExecute")
650
- raise SnowparkConnectNotImplementedError(
978
+
979
+ exception = SnowparkConnectNotImplementedError(
651
980
  "Spark client has detached, please resubmit request. In a future version, the server will be support the reattach."
652
981
  )
982
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
983
+ raise exception
653
984
 
654
985
  def ReleaseExecute(self, request: proto_base.ReleaseExecuteRequest, context):
655
986
  """Release an reattachable execution, or parts thereof.
@@ -666,6 +997,18 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
666
997
  except Exception as e:
667
998
  _handle_exception(context, e)
668
999
 
1000
+ def _cleanup_external_tables(self):
1001
+ external_tables = get_request_external_tables()
1002
+ if not external_tables:
1003
+ return
1004
+ session: snowpark.Session = get_or_create_snowpark_session()
1005
+ for table in external_tables:
1006
+ try:
1007
+ session.sql(f"DROP EXTERNAL TABLE IF EXISTS {table}").collect()
1008
+ except Exception as e:
1009
+ logger.warning(f"Failed to drop external table {table}: {e}")
1010
+ clean_request_external_tables()
1011
+
669
1012
  # TODO: These are required in Spark 4.x.
670
1013
  # def ReleaseSession(self, request, context):
671
1014
  # """Release a session.
@@ -682,39 +1025,16 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
682
1025
  # return super().FetchErrorDetails(request, context)
683
1026
 
684
1027
 
685
- # Global state related to server connection
686
- _server_running: threading.Event = threading.Event()
687
- _server_error: bool = False
688
- _server_url: Optional[str] = None
689
- _client_url: Optional[str] = None
690
-
691
-
692
- # Used to reset server global state to the initial blank slate state if error happens during server startup.
693
- # Called after the startup error is caught and handled / logged etc.
694
- def _reset_server_run_state():
695
- global _server_running, _server_error, _server_url, _client_url
696
- _server_running.clear()
697
- _server_error = False
698
- _server_url = None
699
- _client_url = None
700
-
701
-
702
- def _stop_server(stop_event: threading.Event, server: grpc.Server):
703
- stop_event.wait()
704
- server.stop(0)
705
- _reset_server_run_state()
706
- logger.info("server stop sent")
707
-
708
-
709
1028
  def _serve(
710
1029
  stop_event: Optional[threading.Event] = None,
711
1030
  session: Optional[snowpark.Session] = None,
712
1031
  ):
713
- global _server_running, _server_error
1032
+ server_running = get_server_running()
714
1033
  # TODO: factor out the Snowflake connection code.
715
1034
  server = None
716
1035
  try:
717
1036
  config_snowpark()
1037
+
718
1038
  if session is None:
719
1039
  session = get_or_create_snowpark_session()
720
1040
  else:
@@ -725,33 +1045,16 @@ def _serve(
725
1045
  # No need to start grpc server in TCM
726
1046
  return
727
1047
 
728
- server_options = [
729
- (
730
- "grpc.max_receive_message_length",
731
- get_int_from_env(
732
- "SNOWFLAKE_GRPC_MAX_MESSAGE_SIZE",
733
- _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
734
- ),
735
- ),
736
- (
737
- "grpc.max_metadata_size",
738
- get_int_from_env(
739
- "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
740
- _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
741
- ),
742
- ),
743
- (
744
- "grpc.absolute_max_metadata_size",
745
- get_int_from_env(
746
- "SNOWFLAKE_GRPC_MAX_METADATA_SIZE",
747
- _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
748
- )
749
- * 2,
750
- ),
751
- ]
1048
+ server_options = _get_default_grpc_options()
1049
+
1050
+ # cProfile doesn't work correctly with multiple threads
1051
+ max_workers = 1 if PROFILING_ENABLED else 10
1052
+
752
1053
  server = grpc.server(
753
- futures.ThreadPoolExecutor(max_workers=10), options=server_options
1054
+ futures.ThreadPoolExecutor(max_workers=max_workers),
1055
+ options=server_options,
754
1056
  )
1057
+
755
1058
  control_servicer = ControlServicer(session)
756
1059
  proto_base_grpc.add_SparkConnectServiceServicer_to_server(
757
1060
  SnowflakeConnectServicer(control_servicer.log_spark_connect_batch),
@@ -762,193 +1065,33 @@ def _serve(
762
1065
  server.add_insecure_port(server_url)
763
1066
  logger.info(f"Starting Snowpark Connect server on {server_url}...")
764
1067
  server.start()
765
- _server_running.set()
1068
+ server_running.set()
766
1069
  logger.info("Snowpark Connect server started!")
767
1070
  telemetry.send_server_started_telemetry()
1071
+
768
1072
  if stop_event is not None:
769
1073
  # start a background thread to listen for stop event and terminate the server
770
1074
  threading.Thread(
771
1075
  target=_stop_server, args=(stop_event, server), daemon=True
772
1076
  ).start()
1077
+
773
1078
  server.wait_for_termination()
774
1079
  except Exception as e:
775
- _server_error = True
776
- _server_running.set() # unblock any client sessions
1080
+ set_server_error(True)
1081
+ server_running.set() # unblock any client sessions
777
1082
  if "Invalid connection_name 'spark-connect', known ones are " in str(e):
778
1083
  logger.error(
779
1084
  "Ensure 'spark-connect' connection config has been set correctly in connections.toml."
780
1085
  )
781
1086
  else:
782
1087
  logger.error("Error starting up Snowpark Connect server", exc_info=True)
1088
+ attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
783
1089
  raise e
784
1090
  finally:
785
1091
  # flush the telemetry queue if possible
786
1092
  telemetry.shutdown()
787
-
788
-
789
- def _set_remote_url(remote_url: str):
790
- global _server_url, _client_url
791
- _client_url = remote_url
792
- parsed_url = urllib.parse.urlparse(remote_url)
793
- if parsed_url.scheme == "sc":
794
- _server_url = parsed_url.netloc
795
- server_port = parsed_url.port or DEFAULT_PORT
796
- _check_port_is_free(server_port)
797
- elif parsed_url.scheme == "unix":
798
- _server_url = remote_url.split("/;")[0]
799
- else:
800
- raise RuntimeError(f"Invalid Snowpark Connect URL: {remote_url}")
801
-
802
-
803
- def _set_server_tcp_port(server_port: int):
804
- global _server_url, _client_url
805
- _check_port_is_free(server_port)
806
- _server_url = f"[::]:{server_port}"
807
- _client_url = f"sc://127.0.0.1:{server_port}"
808
-
809
-
810
- def _check_port_is_free(port: int) -> None:
811
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
812
- s.settimeout(1)
813
- if s.connect_ex(("127.0.0.1", port)) == 0:
814
- raise RuntimeError(f"TCP port {port} is already in use")
815
-
816
-
817
- def _set_server_unix_domain_socket(path: str):
818
- global _server_url, _client_url
819
- _server_url = f"unix:{path}"
820
- _client_url = f"unix:{path}"
821
-
822
-
823
- def get_server_url() -> str:
824
- global _server_url
825
- if not _server_url:
826
- raise RuntimeError("Server URL not set")
827
- return _server_url
828
-
829
-
830
- def get_client_url() -> str:
831
- global _client_url
832
- if not _client_url:
833
- raise RuntimeError("Client URL not set")
834
- return _client_url
835
-
836
-
837
- def _make_unix_domain_socket() -> str:
838
- parent_dir = tempfile.mkdtemp()
839
- server_path = os.path.join(parent_dir, "snowflake_sas_grpc.sock")
840
- atexit.register(_cleanup_unix_domain_socket, server_path)
841
- return server_path
842
-
843
-
844
- def _cleanup_unix_domain_socket(server_path: str) -> None:
845
- parent_dir = os.path.dirname(server_path)
846
- if os.path.exists(server_path):
847
- os.remove(server_path)
848
- if os.path.exists(parent_dir):
849
- os.rmdir(parent_dir)
850
-
851
-
852
-class UnixDomainSocketChannelBuilder(ChannelBuilder):
-    """
-    Spark Connect gRPC channel builder for Unix domain sockets
-    """
-
-    def __init__(
-        self, url: str = None, channelOptions: Optional[List[Tuple[str, Any]]] = None
-    ) -> None:
-        if url is None:
-            url = get_client_url()
-        if url[:6] != "unix:/" or len(url) < 7:
-            raise PySparkValueError(
-                error_class="INVALID_CONNECT_URL",
-                message_parameters={
-                    "detail": "The URL must start with 'unix://'. Please update the URL to follow the correct format, e.g., 'unix://unix_domain_socket_path'.",
-                },
-            )
-
-        # Rewrite the URL to use http as the scheme so that we can leverage
-        # Python's built-in parser to parse URL parameters
-        fake_url = "http://" + url[6:]
-        self.url = urllib.parse.urlparse(fake_url)
-        self.params: Dict[str, str] = {}
-        self._extract_attributes()
-
-        # Now parse the real unix domain socket URL
-        self.url = urllib.parse.urlparse(url)
-
-        GRPC_DEFAULT_OPTIONS = [
-            ("grpc.max_send_message_length", _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE),
-            ("grpc.max_receive_message_length", _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE),
-            ("grpc.max_metadata_size", _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE),
-            (
-                "grpc.absolute_max_metadata_size",
-                2 * _SPARK_CONNECT_GRPC_MAX_METADATA_SIZE,
-            ),
-        ]
-
-        if channelOptions is None:
-            self._channel_options = GRPC_DEFAULT_OPTIONS
-        else:
-            self._channel_options = GRPC_DEFAULT_OPTIONS + channelOptions
-        # For Spark 4.0 support, but also backwards compatible.
-        self._params = self.params
-
-    def _extract_attributes(self) -> None:
-        """Extract attributes from parameters.
-
-        This method was copied from
-        https://github.com/apache/spark/blob/branch-3.5/python/pyspark/sql/connect/client/core.py
-
-        This is required for Spark 4.0 support, since it is dropped in favor of moving
-        the extraction logic into the constructor.
-        """
-        if len(self.url.params) > 0:
-            parts = self.url.params.split(";")
-            for p in parts:
-                kv = p.split("=")
-                if len(kv) != 2:
-                    raise PySparkValueError(
-                        error_class="INVALID_CONNECT_URL",
-                        message_parameters={
-                            "detail": f"Parameter '{p}' should be provided as a "
-                            f"key-value pair separated by an equal sign (=). Please update "
-                            f"the parameter to follow the correct format, e.g., 'key=value'.",
-                        },
-                    )
-                self.params[kv[0]] = urllib.parse.unquote(kv[1])
-
-        netloc = self.url.netloc.split(":")
-        if len(netloc) == 1:
-            self.host = netloc[0]
-            if version.parse(pyspark.__version__) >= version.parse("4.0.0"):
-                from pyspark.sql.connect.client.core import DefaultChannelBuilder
-
-                self.port = DefaultChannelBuilder.default_port()
-            else:
-                self.port = ChannelBuilder.default_port()
-        elif len(netloc) == 2:
-            self.host = netloc[0]
-            self.port = int(netloc[1])
-        else:
-            raise PySparkValueError(
-                error_class="INVALID_CONNECT_URL",
-                message_parameters={
-                    "detail": f"Target destination '{self.url.netloc}' should match the "
-                    f"'<host>:<port>' pattern. Please update the destination to follow "
-                    f"the correct format, e.g., 'hostname:port'.",
-                },
-            )
-
-    # We override this to enable compatibility with Spark 4.0
-    host = None
-
-    @property
-    def endpoint(self) -> str:
-        return f"{self.url.scheme}:{self.url.path}"
-
-    def toChannel(self) -> grpc.Channel:
-        return grpc.insecure_channel(self.endpoint, options=self._channel_options)
+    # End the root span when server shuts down completely
+    otel_end_root_span()
 
 
 def config_snowpark() -> None:
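The removed `UnixDomainSocketChannelBuilder` exists so a PySpark Connect client can dial the server over a Unix domain socket instead of TCP; `toChannel()` ultimately opens a plain insecure gRPC channel on the `unix:<path>` endpoint. A hedged sketch of how such a builder is consumed, mirroring the scheme check in the removed `get_session` further down this diff; it assumes a server is already listening and that the class shown above (or its relocated equivalent in this release) is in scope:

```python
# Hedged usage sketch; the socket path is illustrative and the class is assumed importable.
from pyspark.sql.connect.session import SparkSession

url = "unix:/tmp/snowpark/snowflake_sas_grpc.sock"  # illustrative path
if url.startswith("unix:/"):
    # Unix domain socket: needs the custom channel builder shown (removed) above.
    builder = SparkSession.builder.channelBuilder(UnixDomainSocketChannelBuilder(url))
else:
    # Plain TCP: the stock remote() path, e.g. "sc://127.0.0.1:15002".
    builder = SparkSession.builder.remote(url)

spark = builder.getOrCreate()
spark.range(3).show()
```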
@@ -977,12 +1120,24 @@ def start_jvm():
         if tcm.TCM_MODE:
             # No-op if JVM is already started in TCM mode
             return
-        raise RuntimeError(
+        exception = RuntimeError(
             "JVM must not be running when starting the Spark Connect server"
         )
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
 
+    import pathlib
+    import zipfile
+
+    import snowflake.snowpark_connect
+
+    # Import both JAR dependency packages
+    import snowpark_connect_deps_1
+    import snowpark_connect_deps_2
+
+    # First, add JARs from includes/jars directory
     pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes/jars"
+        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes" / "jars"
     )
 
     if "dataframe_processor.zip" in str(pyspark_jars):
@@ -991,18 +1146,31 @@ def start_jvm():
             snowflake.snowpark_connect.__file__
         ).parent.parent.parent
         temp_dir = tempfile.gettempdir()
-
         extract_folder = "snowflake/snowpark_connect/includes/jars/"  # Folder to extract (must end with '/')
 
         with zipfile.ZipFile(zip_path, "r") as zip_ref:
             for member in zip_ref.namelist():
                 if member.startswith(extract_folder):
                     zip_ref.extract(member, path=temp_dir)
-
         pyspark_jars = pathlib.Path(temp_dir) / extract_folder
 
-    for path in pyspark_jars.glob("**/*.jar"):
-        jpype.addClassPath(path)
+    included_jar_names = set()
+
+    if pyspark_jars.exists():
+        for jar_path in pyspark_jars.glob(
+            "**/*.jar"
+        ):  # Use **/*.jar to handle nested paths in TCM
+            jpype.addClassPath(str(jar_path))
+            included_jar_names.add(jar_path.name)
+
+    # Load jar files from both packages, skipping those already loaded from includes/jars
+    jar_path_list = (
+        snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars()
+    )
+    for jar_path in jar_path_list:
+        # Skip if this JAR was already loaded from includes/jars
+        if jar_path.name not in included_jar_names:
+            jpype.addClassPath(jar_path)
 
     # TODO: Should remove convertStrings, but it breaks the JDBC code.
     jvm_settings: list[str] = list(
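The rewritten classpath assembly above tracks the file names of JARs already registered from `includes/jars` and skips any dependency-package JAR with the same name, so a bundled JAR wins over the copies shipped by the two deps packages. A standalone sketch of that dedup-by-filename rule (paths and names are illustrative; the real code passes each winning path to `jpype.addClassPath`):

```python
# Hedged sketch of the "bundled JARs win" dedup rule used when building the JVM classpath.
from pathlib import Path
from typing import Iterable, List


def build_classpath(bundled_jars: Iterable[Path], dep_jars: Iterable[Path]) -> List[Path]:
    """Bundled JARs are kept; dependency JARs with a duplicate file name are skipped."""
    classpath: List[Path] = []
    seen_names = set()
    for jar in bundled_jars:  # e.g. everything under includes/jars/**/*.jar
        classpath.append(jar)
        seen_names.add(jar.name)
    for jar in dep_jars:  # e.g. the lists returned by the two deps packages
        if jar.name not in seen_names:
            classpath.append(jar)
    return classpath


bundled = [Path("/pkg/includes/jars/example-udf.jar")]
deps = [Path("/deps/example-udf.jar"), Path("/deps/example-runtime.jar")]
print([p.name for p in build_classpath(bundled, deps)])
# ['example-udf.jar', 'example-runtime.jar']
```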
@@ -1027,6 +1195,7 @@ def start_session(
     snowpark_session: Optional[snowpark.Session] = None,
     connection_parameters: Optional[Dict[str, str]] = None,
     max_grpc_message_size: int = _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
+    _add_signal_handler: bool = False,
 ) -> threading.Thread | None:
     """
     Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.
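For orientation, a hedged sketch of how this entry point is typically driven. The parameter names come from the signature and body shown in this diff; the import path, credential values, and port are assumptions rather than documented values. `get_session` is the companion helper that `init_spark_session` (further down in this diff) still calls to hand back a PySpark Connect session:

```python
# Hedged usage sketch; import path, credentials, and port are assumptions.
from snowflake import snowpark_connect

connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
}

# Start the server on a fixed TCP port as a daemon thread...
snowpark_connect.start_session(
    tcp_port=15002,
    connection_parameters=connection_parameters,
    is_daemon=True,
)

# ...then attach a Spark Connect session to it.
spark = snowpark_connect.get_session()
spark.sql("SELECT 1 AS ONE").show()
```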
@@ -1048,147 +1217,80 @@ def start_session(
         connection_parameters: A dictionary of connection parameters to use to create the Snowpark session. If this is
             provided, the `snowpark_session` parameter must be None.
     """
-    try:
-        # Changing the value of our global variable based on the grpc message size provided by the user.
-        global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
-        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size
+    # Increase recursion limit to 1100 (1000 by default)
+    # introduced due to Scala OSS Test: org.apache.spark.sql.ClientE2ETestSuite.spark deep recursion
+    sys.setrecursionlimit(1100)
 
-        from pyspark.sql.connect.client import ChannelBuilder
+    # Apply PySpark Connect client patching for enhanced debugging (only if telemetry is enabled)
+    from snowflake.snowpark_connect.utils.patch_spark_line_number import (
+        patch_pyspark_connect,
+    )
 
-        ChannelBuilder.MAX_MESSAGE_LENGTH = max_grpc_message_size
+    if is_telemetry_enabled():
+        patch_pyspark_connect()
 
-        if os.environ.get("SPARK_ENV_LOADED"):
-            raise RuntimeError(
-                "Snowpark Connect cannot be run inside of a Spark environment"
-            )
-        if connection_parameters is not None:
-            if snowpark_session is not None:
-                raise ValueError(
-                    "Only specify one of snowpark_session and connection_parameters"
-                )
-            snowpark_session = snowpark.Session.builder.configs(
-                connection_parameters
-            ).create()
+    try:
+        # Set max grpc message size if provided
+        if max_grpc_message_size is not None:
+            set_grpc_max_message_size(max_grpc_message_size)
 
-        global _server_running, _server_error
-        if _server_running.is_set():
+        # Validate startup parameters
+        snowpark_session = validate_startup_parameters(
+            snowpark_session, connection_parameters
+        )
+
+        server_running = get_server_running()
+        if server_running.is_set():
             url = get_client_url()
             logger.warning(f"Snowpark Connect session is already running at {url}")
             return
 
-        if len(list(filter(None, [remote_url, tcp_port, unix_domain_socket]))) > 1:
-            raise RuntimeError(
-                "Can only set at most one of remote_url, tcp_port, and unix_domain_socket"
-            )
-
-        url_from_env = os.environ.get("SPARK_REMOTE", None)
-        if remote_url:
-            _set_remote_url(remote_url)
-        elif tcp_port:
-            _set_server_tcp_port(tcp_port)
-        elif unix_domain_socket:
-            _set_server_unix_domain_socket(unix_domain_socket)
-        elif url_from_env:
-            # Spark clients use environment variable SPARK_REMOTE to figure out Spark Connect URL. If none of the
-            # connection properties (remote_url, tcp_port, unix_domain_socket) are explicitly passed in to this
-            # function then we should try and mimic clients' behavior
-            # i.e. read the server URL from the SPARK_REMOTE environment variable.
-            _set_remote_url(url_from_env)
-        else:
-            # No connection properties can be found at all - either as arguments to this function or int the environment
-            # variable. We use random, unique Unix Domain Socket as a last fallback. Client can connect to this randomly
-            # generated UDS port using snowpark_connect.get_session().
-            # Mostly used in stored procs and Notebooks to avoid port conflicts.
-            if os.name == "nt":
-                # Windows does not support unix domain sockets, so use default TCP port instead.
-                _set_server_tcp_port(DEFAULT_PORT)
-            else:
-                # Generate unique, random UDS port. Mostly useful in stored proc environment to avoid port conflicts.
-                unix_domain_socket = _make_unix_domain_socket()
-                _set_server_unix_domain_socket(unix_domain_socket)
+        configure_server_url(remote_url, tcp_port, unix_domain_socket)
 
         start_jvm()
         _disable_protobuf_recursion_limit()
+        otel_initialize()
+
+        if _add_signal_handler:
+            setup_signal_handlers(stop_event)
 
         if is_daemon:
            arguments = (stop_event, snowpark_session)
-            # `daemon=True` ensures the server thread exits when script finishes.
-            server_thread = threading.Thread(target=_serve, args=arguments, daemon=True)
-            server_thread.start()
-            _server_running.wait()
-            if _server_error:
-                raise RuntimeError("Snowpark Connect session failed to start")
-            return server_thread
-        else:
-            # Launch in the foreground.
-            _serve(session=snowpark_session)
-    except Exception as e:
-        _reset_server_run_state()
-        logger.error(e, exc_info=True)
-        raise e
 
+            target_func = otel_create_context_wrapper(_serve)
 
-def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession:
-    """
-    Returns spark connect session
-
-    Parameters:
-        url (Optional[str]): Spark connect server URL. Uses default server URL if none is provided.
-
-    Returns:
-        A new spark connect session
-
-    Raises:
-        RuntimeError: If Spark Connect server is not started.
-    """
-    try:
-        if not url:
-            url = get_client_url()
+            server_thread = threading.Thread(
+                target=target_func, args=arguments, daemon=True
+            )
+            server_thread.start()
+            server_running.wait()
+            if get_server_error():
+                exception = RuntimeError("Snowpark Connect session failed to start")
+                attach_custom_error_code(
+                    exception, ErrorCodes.STARTUP_CONNECTION_FAILED
+                )
+                raise exception
 
-        if url.startswith("unix:/"):
-            b = SparkSession.builder.channelBuilder(UnixDomainSocketChannelBuilder())
+            return server_thread
         else:
-            b = SparkSession.builder.remote(url)
-
-        if conf is not None:
-            for k, v in conf.getAll():
-                b.config(k, v)
-
-        return b.getOrCreate()
+            # Launch in the foreground with stop_event
+            _serve(stop_event=stop_event, session=snowpark_session)
     except Exception as e:
         _reset_server_run_state()
         logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
         raise e
 
 
 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-    try:
-        # For Notebooks on SPCS
-        from jdk4py import JAVA_HOME
-
-        os.environ["JAVA_HOME"] = str(JAVA_HOME)
-    except ModuleNotFoundError:
-        # For notebooks on Warehouse
-        os.environ["JAVA_HOME"] = os.environ["CONDA_PREFIX"]
-        os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
-            os.environ["CONDA_PREFIX"], "lib", "server"
-        )
-    logger.info("JAVA_HOME=%s", os.environ["JAVA_HOME"])
+    _setup_spark_environment()
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
 
-    os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
-    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
-
-    snowpark_session = snowpark.context.get_active_session()
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)
 
 
-def enable_debug_logging():
-    logger.setLevel(logging.DEBUG)
-    for handler in logger.handlers:
-        handler.setLevel(logging.DEBUG)
-
-
 def _get_files_metadata(data_source: relations_proto.Read.DataSource) -> List[str]:
     # TODO: Handle paths on the cloud
     paths = data_source.paths
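Both the old and the new daemon paths in the hunk above follow the same readiness protocol: run the server on a daemon thread, block on an event until the server reports itself running, then check an error flag and fail fast. A self-contained sketch of that protocol with a stand-in server function (the real code additionally wraps `_serve` in an OpenTelemetry context wrapper before handing it to the thread):

```python
# Hedged sketch of the daemon-thread readiness protocol used by start_session.
import threading
import time


def start_in_daemon_thread(serve, timeout: float = 10.0) -> threading.Thread:
    running = threading.Event()
    errors = []

    def wrapper():
        try:
            serve(running)  # the server signals readiness via the event
        except BaseException as e:  # record the failure and unblock the waiter
            errors.append(e)
            running.set()

    # daemon=True ensures the server thread exits when the script finishes.
    thread = threading.Thread(target=wrapper, daemon=True)
    thread.start()
    if not running.wait(timeout) or errors:
        raise RuntimeError("Snowpark Connect session failed to start")
    return thread


def fake_serve(running: threading.Event) -> None:
    running.set()   # pretend startup succeeded
    time.sleep(0.1)  # ...and serve briefly


start_in_daemon_thread(fake_serve)
print("server thread started")
```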
@@ -1206,15 +1308,3 @@ def _get_files_metadata(data_source: relations_proto.Read.DataSource) -> List[st
         ]
     )
     return files
-
-
-def _disable_protobuf_recursion_limit():
-    # https://github.com/protocolbuffers/protobuf/blob/960e79087b332583c80537c949621108a85aa442/src/google/protobuf/io/coded_stream.h#L616
-    # Disable protobuf recursion limit (default 100) because Spark workloads often produce deeply nested execution plans. For example:
-    # - Queries with many unions
-    # - Complex expressions with multiple levels of nesting
-    # Without this, legitimate Spark queries would fail with `(DecodeError) Error parsing message with type 'spark.connect.Relation'` error.
-    # see test_sql_resulting_in_nested_protobuf
-    from google.protobuf.pyext import cpp_message
-
-    cpp_message._message.SetAllowOversizeProtos(True)
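`_disable_protobuf_recursion_limit` disappears from this module even though it is still called earlier in this diff, so it has presumably moved elsewhere in the package. For reference, a hedged sketch of the same relaxation with a guard for protobuf builds (pure-Python or upb) that do not expose the C++ extension hook; this is an illustration, not the package's implementation:

```python
# Hedged sketch: relax the protobuf oversize/recursion guard where the hook exists.
def allow_oversize_protos() -> bool:
    try:
        from google.protobuf.pyext import cpp_message

        cpp_message._message.SetAllowOversizeProtos(True)
        return True
    except (ImportError, AttributeError):
        # The active protobuf implementation has no such switch here; leave defaults alone.
        return False


print(allow_oversize_protos())
```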