snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +717 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +309 -26
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  23. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  24. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  25. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  26. snowflake/snowpark_connect/expression/literal.py +37 -13
  27. snowflake/snowpark_connect/expression/map_cast.py +224 -15
  28. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  29. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  30. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  31. snowflake/snowpark_connect/expression/map_udf.py +86 -20
  32. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  33. snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
  34. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  35. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  36. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  37. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  39. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
  43. snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
  44. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  45. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  46. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  47. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  48. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  49. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  50. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  51. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  52. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  53. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  54. snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
  55. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  56. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  57. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  58. snowflake/snowpark_connect/relation/map_join.py +683 -442
  59. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  60. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  61. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  62. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  63. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  64. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  65. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  66. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  67. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  68. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  69. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  70. snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
  71. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
  72. snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
  73. snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
  74. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  75. snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
  76. snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
  77. snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
  78. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  79. snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
  80. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  81. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  82. snowflake/snowpark_connect/relation/utils.py +128 -5
  83. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  84. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  85. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  86. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  87. snowflake/snowpark_connect/resources_initializer.py +171 -48
  88. snowflake/snowpark_connect/server.py +528 -473
  89. snowflake/snowpark_connect/server_common/__init__.py +503 -0
  90. snowflake/snowpark_connect/snowflake_session.py +65 -0
  91. snowflake/snowpark_connect/start_server.py +53 -5
  92. snowflake/snowpark_connect/type_mapping.py +349 -27
  93. snowflake/snowpark_connect/type_support.py +130 -0
  94. snowflake/snowpark_connect/typed_column.py +9 -7
  95. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  96. snowflake/snowpark_connect/utils/cache.py +49 -27
  97. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  98. snowflake/snowpark_connect/utils/context.py +195 -37
  99. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  100. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  101. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  102. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  103. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  104. snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
  105. snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
  106. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  107. snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
  108. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  109. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  110. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  111. snowflake/snowpark_connect/utils/profiling.py +25 -8
  112. snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
  113. snowflake/snowpark_connect/utils/sequence.py +21 -0
  114. snowflake/snowpark_connect/utils/session.py +64 -28
  115. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  116. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  117. snowflake/snowpark_connect/utils/telemetry.py +192 -40
  118. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  119. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  120. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  121. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  122. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  123. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  124. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  125. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  126. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  127. snowflake/snowpark_connect/version.py +1 -1
  128. snowflake/snowpark_decoder/dp_session.py +6 -2
  129. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  130. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
  131. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
  132. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
  133. snowflake/snowpark_connect/hidden_column.py +0 -39
  134. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  186. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  187. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  188. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  189. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  190. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  191. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  192. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  193. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  194. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  195. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  196. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  197. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  198. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  199. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  200. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/client/server.py (new file)
@@ -0,0 +1,717 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Remote Spark Connect Server for Snowpark Connect.
+
+Lightweight servicer that forwards Spark Connect requests to Snowflake backend
+via REST API using SparkConnectResource SDK.
+"""
+
+import threading
+import uuid
+from concurrent import futures
+from typing import Dict, Iterator, Optional
+
+import grpc
+import pyarrow as pa
+from google.rpc import code_pb2
+from grpc_status import rpc_status
+from pyspark.conf import SparkConf
+from pyspark.sql.connect.proto import base_pb2, base_pb2_grpc, types_pb2
+from pyspark.sql.connect.session import SparkSession
+from snowflake.core.spark_connect._spark_connect import SparkConnectResource
+
+from snowflake import snowpark
+from snowflake.snowpark import Session
+from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.client.exceptions import (
+    GrpcErrorStatusException,
+    UnexpectedResponseException,
+)
+from snowflake.snowpark_connect.client.query_results import (
+    fetch_query_result_as_arrow_batches,
+    fetch_query_result_as_protobuf,
+)
+from snowflake.snowpark_connect.client.utils.session import (
+    get_or_create_snowpark_session,
+)
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.server_common import (  # noqa: F401 - re-exported for public API
+    _disable_protobuf_recursion_limit,
+    _get_default_grpc_options,
+    _reset_server_run_state,
+    _setup_spark_environment,
+    _stop_server,
+    configure_server_url,
+    get_client_url,
+    get_server_error,
+    get_server_running,
+    get_server_url,
+    get_session,
+    set_grpc_max_message_size,
+    set_server_error,
+    setup_signal_handlers,
+    validate_startup_parameters,
+)
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+from snowflake.snowpark_connect.utils.env_utils import get_int_from_env
+from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
+from snowflake.snowpark_connect.utils.telemetry import telemetry
+from spark.connect import envelope_pb2
+
+
+def _log_and_return_error(
+    err_mesg: str,
+    error: Exception,
+    status_code: grpc.StatusCode,
+    context: grpc.ServicerContext,
+) -> None:
+    """Log error and set gRPC context."""
+    context.set_details(str(error))
+    context.set_code(status_code)
+    logger.error(f"{err_mesg} status code: {status_code}", exc_info=True)
+    return None
+
+
+def _validate_response_type(
+    resp_envelope: envelope_pb2.ResponseEnvelope, expected_field: str
+) -> None:
+    """Validate that response envelope has expected type."""
+    field_name = resp_envelope.WhichOneof("response_type")
+    if field_name != expected_field:
+        raise UnexpectedResponseException(
+            f"Expected response type {expected_field}, got {field_name}"
+        )
+
+
+def _build_result_complete_response(
+    request: base_pb2.ExecutePlanRequest,
+) -> base_pb2.ExecutePlanResponse:
+    return base_pb2.ExecutePlanResponse(
+        session_id=request.session_id,
+        operation_id=request.operation_id or "0",
+        result_complete=base_pb2.ExecutePlanResponse.ResultComplete(),
+    )
+
+
+def _build_exec_plan_resp_stream_from_df_query_result(
+    request: base_pb2.ExecutePlanRequest,
+    session: Session,
+    query_result: envelope_pb2.DataframeQueryResult,
+) -> Iterator[base_pb2.ExecutePlanResponse]:
+    query_id = query_result.result_job_uuid
+    arrow_schema = pa.ipc.read_schema(pa.BufferReader(query_result.arrow_schema))
+    spark_schema = types_pb2.DataType()
+    spark_schema.ParseFromString(query_result.spark_schema)
+
+    for row_count, arrow_batch_bytes in fetch_query_result_as_arrow_batches(
+        session, query_id, arrow_schema
+    ):
+        yield base_pb2.ExecutePlanResponse(
+            session_id=request.session_id,
+            operation_id=request.operation_id or "0",
+            arrow_batch=base_pb2.ExecutePlanResponse.ArrowBatch(
+                row_count=row_count,
+                data=arrow_batch_bytes,
+            ),
+            schema=spark_schema,
+        )
+
+    yield _build_result_complete_response(request)
+
+
+def _build_exec_plan_resp_stream_from_resp_envelope(
+    request: base_pb2.ExecutePlanRequest,
+    session: Session,
+    resp_envelope: envelope_pb2.ResponseEnvelope,
+) -> Iterator[base_pb2.ExecutePlanResponse]:
+    """Build execution plan response stream from response envelope."""
+    resp_type = resp_envelope.WhichOneof("response_type")
+
+    if resp_type == "dataframe_query_result":
+        query_result = resp_envelope.dataframe_query_result
+        yield from _build_exec_plan_resp_stream_from_df_query_result(
+            request, session, query_result
+        )
+    elif resp_type == "execute_plan_response":
+        yield resp_envelope.execute_plan_response
+        yield _build_result_complete_response(request)
+    elif resp_type == "status":
+        raise GrpcErrorStatusException(resp_envelope.status)
+    else:
+        logger.warning(f"Unexpected response type: {resp_type}")
+
+
+class SnowflakeConnectClientServicer(base_pb2_grpc.SparkConnectServiceServicer):
+    # Configs frequently read by PySpark client but not supported in SAS.
+    # We return the request's default value directly without calling the backend.
+    #
+    # Why this matters:
+    # - These configs are read via get_config_with_defaults() on every show()/toPandas()
+    # - In SAS, show() internally calls toPandas() (see pyspark/sql/connect/dataframe.py _show_string)
+    # - Without this optimization, every show() would make a backend config request
+    #
+    # Source: OSS Spark python/pyspark/sql/connect/client/core.py to_pandas() method
+    _UNSUPPORTED_CONFIGS: frozenset[str] = frozenset(
+        {
+            # Read on every show() and toPandas() call - controls Arrow memory optimization
+            "spark.sql.execution.arrow.pyspark.selfDestruct.enabled",
+            # Read on toPandas() when DataFrame contains StructType fields
+            "spark.sql.execution.pandas.structHandlingMode",
+        }
+    )
+
+    def __init__(self, snowpark_session: Session) -> None:
+        self.snowpark_session = snowpark_session
+        self._config_cache: SynchronizedDict[str, str] = SynchronizedDict()
+
+    def _get_spark_resource(self) -> SparkConnectResource:
+        return SparkConnectResource(self.snowpark_session)
+
+    def _parse_response_envelope(
+        self, response_bytes: bytes | bytearray, expected_resp_type: str = None
+    ) -> envelope_pb2.ResponseEnvelope:
+        """Parse and validate response envelope from GS backend."""
+
+        resp_envelope = envelope_pb2.ResponseEnvelope()
+        if isinstance(response_bytes, bytearray):
+            response_bytes = bytes(response_bytes)
+        resp_envelope.ParseFromString(response_bytes)
+
+        resp_type = resp_envelope.WhichOneof("response_type")
+        if resp_type == "status":
+            raise GrpcErrorStatusException(resp_envelope.status)
+
+        if not resp_envelope.query_id and not resp_type == "dataframe_query_result":
+            _validate_response_type(resp_envelope, expected_resp_type)
+
+        return resp_envelope
+
+    def ExecutePlan(
+        self, request: base_pb2.ExecutePlanRequest, context: grpc.ServicerContext
+    ) -> Iterator[base_pb2.ExecutePlanResponse]:
+        """Execute a Spark plan by forwarding to GS backend."""
+        logger.debug("Received Execute Plan request")
+        query_id = None
+        telemetry.initialize_request_summary(request)
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.execute_plan(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "execute_plan_response"
+            )
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                job_res_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, resp_envelope.query_id
+                )
+                yield from _build_exec_plan_resp_stream_from_resp_envelope(
+                    request, self.snowpark_session, job_res_envelope
+                )
+            else:
+                yield from _build_exec_plan_resp_stream_from_resp_envelope(
+                    request, self.snowpark_session, resp_envelope
+                )

+        except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            telemetry.report_request_failure(e)
+            logger.error(f"Error in ExecutePlan, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in ExecutePlan call", e, grpc.StatusCode.INTERNAL, context
+            )
+        finally:
+            telemetry.send_request_summary_telemetry()
+
+    def _call_backend_config(
+        self, request: base_pb2.ConfigRequest
+    ) -> base_pb2.ConfigResponse:
+        """Forward config request to GS and return response."""
+        spark_resource = self._get_spark_resource()
+        response_bytes = spark_resource.config(request.SerializeToString())
+        resp_envelope = self._parse_response_envelope(response_bytes, "config_response")
+
+        query_id = resp_envelope.query_id
+        if query_id:
+            resp_envelope = fetch_query_result_as_protobuf(
+                self.snowpark_session, query_id
+            )
+            assert resp_envelope.WhichOneof("response_type") == "config_response"
+
+        return resp_envelope.config_response
+
+    def _handle_get_cached_config_request(
+        self,
+        request: base_pb2.ConfigRequest,
+        items: list[tuple[str, Optional[str]]],
+        op_name: str,
+        update_cache_on_miss: bool = True,
+    ) -> base_pb2.ConfigResponse:
+        """
+        Handle config requests with caching and unsupported config checks.
+
+        Args:
+            request: The original ConfigRequest.
+            items: List of (key, default_value) tuples. default_value is None for get/get_option.
+            op_name: Name of the operation for logging (e.g., "get", "get_with_default").
+            update_cache_on_miss: Whether to update the cache with values returned from backend.
+        """
+        keys = [k for k, _ in items]
+
+        # 1. Unsupported Configs Check
+        # If all keys are unsupported, return defaults (if any) or empty response without calling backend
+        if all(key in self._UNSUPPORTED_CONFIGS for key in keys):
+            response = base_pb2.ConfigResponse(session_id=request.session_id)
+            for key, default_val in items:
+                resp_pair = response.pairs.add()
+                resp_pair.key = key
+                if default_val is not None:
+                    resp_pair.value = default_val
+            logger.debug(f"Config {op_name} returning defaults for unsupported: {keys}")
+            return response
+
+        # 2. Cache Check
+        # Check if all keys are in cache
+        cached_values = {key: self._config_cache.get(key) for key in keys}
+        if cached_values and all(value is not None for value in cached_values.values()):
+            response = base_pb2.ConfigResponse(session_id=request.session_id)
+            for key in keys:
+                resp_pair = response.pairs.add()
+                resp_pair.key = key
+                resp_pair.value = cached_values[key]
+            logger.debug(f"Config {op_name} served from cache: {keys}")
+            return response
+
+        # 3. Cache Miss - Call Backend
+        config_response = self._call_backend_config(request)
+
+        if update_cache_on_miss:
+            for pair in config_response.pairs:
+                if pair.HasField("value"):
+                    self._config_cache[pair.key] = pair.value
+            logger.debug(f"Config {op_name} cached from backend: {keys}")
+        else:
+            logger.debug(f"Config {op_name} from backend (not cached): {keys}")
+
+        return config_response
+
+    def Config(
+        self, request: base_pb2.ConfigRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ConfigResponse:
+        logger.debug("Received Config request")
+        telemetry.initialize_request_summary(request)
+
+        try:
+            op = request.operation
+            op_type = op.WhichOneof("op_type")
+
+            match op_type:
+                case "get_with_default":
+                    pairs = op.get_with_default.pairs
+                    items = [
+                        (p.key, p.value if p.HasField("value") else None) for p in pairs
+                    ]
+                    return self._handle_get_cached_config_request(
+                        request, items, "get_with_default", update_cache_on_miss=False
+                    )
+
+                case "get":
+                    keys = op.get.keys
+                    items = [(k, None) for k in keys]
+                    return self._handle_get_cached_config_request(request, items, "get")
+
+                case "set":
+                    config_response = self._call_backend_config(request)
+
+                    for pair in op.set.pairs:
+                        if pair.HasField("value"):
+                            self._config_cache[pair.key] = pair.value
+                    logger.debug(
+                        f"Config set updated cache: {[p.key for p in op.set.pairs]}"
+                    )
+                    return config_response
+
+                case "unset":
+                    config_response = self._call_backend_config(request)
+
+                    for key in op.unset.keys:
+                        self._config_cache.remove(key)
+                    logger.debug(f"Config unset updated cache: {list(op.unset.keys)}")
+                    return config_response
+
+                case "get_option":
+                    keys = op.get_option.keys
+                    items = [(k, None) for k in keys]
+                    return self._handle_get_cached_config_request(
+                        request, items, "get_option"
+                    )
+
+                case "get_all":
+                    # Always call backend since this is a prefix-based search and we
+                    # can't know if all matching keys are in cache. Cache the results.
+                    config_response = self._call_backend_config(request)
+
+                    # Cache all returned values
+                    for pair in config_response.pairs:
+                        if pair.HasField("value"):
+                            self._config_cache[pair.key] = pair.value
+                    prefix = (
+                        op.get_all.prefix if op.get_all.HasField("prefix") else "all"
+                    )
+                    logger.debug(
+                        f"Config get_all cached {len(config_response.pairs)} items (prefix={prefix})"
+                    )
+                    return config_response
+
+                case _:
+                    # Forward other operations to backend (no caching)
+                    logger.debug(
+                        f"Forwarding unknown config request of type {op_type} to the backend"
+                    )
+                    return self._call_backend_config(request)
+
+        except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            telemetry.report_request_failure(e)
+            logger.error("Error in Config", exc_info=True)
+            return _log_and_return_error(
+                "Error in Config call", e, grpc.StatusCode.INTERNAL, context
+            )
+        finally:
+            telemetry.send_request_summary_telemetry()
+
+    def AnalyzePlan(
+        self, request: base_pb2.AnalyzePlanRequest, context: grpc.ServicerContext
+    ) -> base_pb2.AnalyzePlanResponse:
+        logger.debug("Received Analyze Plan request")
+        query_id = None
+        telemetry.initialize_request_summary(request)
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.analyze_plan(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "analyze_plan_response"
+            )
+
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                resp_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, query_id
+                )
+                assert (
+                    resp_envelope.WhichOneof("response_type") == "analyze_plan_response"
+                )
+
+            return resp_envelope.analyze_plan_response
+
+        except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            telemetry.report_request_failure(e)
+            logger.error(f"Error in AnalyzePlan, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in AnalyzePlan call", e, grpc.StatusCode.INTERNAL, context
+            )
+        finally:
+            telemetry.send_request_summary_telemetry()
+
+    def AddArtifacts(
+        self,
+        request_iterator: Iterator[base_pb2.AddArtifactsRequest],
+        context: grpc.ServicerContext,
+    ) -> base_pb2.AddArtifactsResponse:
+        logger.debug("Received AddArtifacts request")
+        add_artifacts_response = None
+
+        spark_resource = self._get_spark_resource()
+
+        for request in request_iterator:
+            query_id = None
+            telemetry.initialize_request_summary(request)
+            try:
+                response_bytes = spark_resource.add_artifacts(
+                    request.SerializeToString()
+                )
+                resp_envelope = self._parse_response_envelope(
+                    response_bytes, "add_artifacts_response"
+                )
+
+                query_id = resp_envelope.query_id
+
+                if query_id:
+                    resp_envelope = fetch_query_result_as_protobuf(
+                        self.snowpark_session, query_id
+                    )
+                    assert (
+                        resp_envelope.WhichOneof("response_type")
+                        == "add_artifacts_response"
+                    )
+
+                add_artifacts_response = resp_envelope.add_artifacts_response
+
+            except GrpcErrorStatusException as e:
+                telemetry.report_request_failure(e)
+                context.abort_with_status(rpc_status.to_status(e.status))
+            except Exception as e:
+                telemetry.report_request_failure(e)
+                logger.error(
+                    f"Error in AddArtifacts, query id {query_id}", exc_info=True
+                )
+                return _log_and_return_error(
+                    "Error in AddArtifacts call", e, grpc.StatusCode.INTERNAL, context
+                )
+            finally:
+                telemetry.send_request_summary_telemetry()
+
+        if add_artifacts_response is None:
+            raise ValueError("AddArtifacts received empty request_iterator")
+
+        return add_artifacts_response
+
+    def ArtifactStatus(
+        self, request: base_pb2.ArtifactStatusesRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ArtifactStatusesResponse:
+        """Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]"""
+        logger.debug("Received ArtifactStatus request")
+        query_id = None
+        telemetry.initialize_request_summary(request)
+
+        try:
+            spark_resource = self._get_spark_resource()
+            response_bytes = spark_resource.artifact_status(request.SerializeToString())
+            resp_envelope = self._parse_response_envelope(
+                response_bytes, "artifact_status_response"
+            )
+
+            query_id = resp_envelope.query_id
+
+            if query_id:
+                resp_envelope = fetch_query_result_as_protobuf(
+                    self.snowpark_session, query_id
+                )
+                assert (
+                    resp_envelope.WhichOneof("response_type")
+                    == "artifact_status_response"
+                )
+
+            return resp_envelope.artifact_status_response
+        except GrpcErrorStatusException as e:
+            telemetry.report_request_failure(e)
+            context.abort_with_status(rpc_status.to_status(e.status))
+        except Exception as e:
+            telemetry.report_request_failure(e)
+            logger.error(f"Error in ArtifactStatus, query id {query_id}", exc_info=True)
+            return _log_and_return_error(
+                "Error in ArtifactStatus call", e, grpc.StatusCode.INTERNAL, context
+            )
+        finally:
+            telemetry.send_request_summary_telemetry()
+
+    def Interrupt(
+        self, request: base_pb2.InterruptRequest, context: grpc.ServicerContext
+    ) -> base_pb2.InterruptResponse:
+        """Interrupt running executions."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method Interrupt not implemented!")
+        raise NotImplementedError("Method Interrupt not implemented!")
+
+    def ReleaseExecute(
+        self, request: base_pb2.ReleaseExecuteRequest, context: grpc.ServicerContext
+    ) -> base_pb2.ReleaseExecuteResponse:
+        """Release an execution."""
+        logger.debug("Received Release Execute request")
+        telemetry.initialize_request_summary(request)
+        try:
+            return base_pb2.ReleaseExecuteResponse(
+                session_id=request.session_id,
+                operation_id=request.operation_id or str(uuid.uuid4()),
+            )
+        except Exception as e:
+            telemetry.report_request_failure(e)
+            logger.error("Error in ReleaseExecute", exc_info=True)
+            return _log_and_return_error(
+                "Error in ReleaseExecute call", e, grpc.StatusCode.INTERNAL, context
+            )
+        finally:
+            telemetry.send_request_summary_telemetry()
+
+    def ReattachExecute(
+        self, request: base_pb2.ReattachExecuteRequest, context: grpc.ServicerContext
+    ) -> Iterator[base_pb2.ExecutePlanResponse]:
+        """Reattach to an existing reattachable execution.
+        The ExecutePlan must have been started with ReattachOptions.reattachable=true.
+        If the ExecutePlanResponse stream ends without a ResultComplete message, there is more to
+        continue. If there is a ResultComplete, the client should use ReleaseExecute with
+        """
+        from google.rpc import status_pb2
+
+        status = status_pb2.Status(
+            code=code_pb2.UNIMPLEMENTED,
+            message="Method ReattachExecute not implemented! INVALID_HANDLE.OPERATION_NOT_FOUND",
+        )
+        context.abort_with_status(rpc_status.to_status(status))
+
+
+def _serve(
+    stop_event: Optional[threading.Event] = None,
+    session: Optional[snowpark.Session] = None,
+) -> None:
+    server_running = get_server_running()
+    try:
+        if session is None:
+            session = get_or_create_snowpark_session()
+
+        # Initialize telemetry with session and thin client source identifier
+        telemetry.initialize(session, source="SparkConnectLightWeightClient")
+
+        server_options = _get_default_grpc_options()
+        max_workers = get_int_from_env("SPARK_CONNECT_CLIENT_GRPC_MAX_WORKERS", 10)
+
+        server = grpc.server(
+            futures.ThreadPoolExecutor(max_workers=max_workers),
+            options=server_options,
+        )
+
+        base_pb2_grpc.add_SparkConnectServiceServicer_to_server(
+            SnowflakeConnectClientServicer(session),
+            server,
+        )
+        server_url = get_server_url()
+        server.add_insecure_port(server_url)
+        logger.info(f"Starting Snowpark Connect server on {server_url}...")
+        server.start()
+        server_running.set()
+        logger.info("Snowpark Connect server started!")
+        telemetry.send_server_started_telemetry()
+
+        if stop_event is not None:
+            # start a background thread to listen for stop event and terminate the server
+            threading.Thread(
+                target=_stop_server, args=(stop_event, server), daemon=True
+            ).start()
+
+        server.wait_for_termination()
+    except Exception as e:
+        set_server_error(True)
+        server_running.set()  # unblock any client sessions
+        if "Invalid connection_name 'spark-connect', known ones are " in str(e):
+            logger.error(
+                "Ensure 'spark-connect' connection config has been set correctly in connections.toml."
+            )
+        else:
+            logger.error("Error starting up Snowpark Connect server", exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
+        raise e
+    finally:
+        # Flush the telemetry queue if possible
+        telemetry.shutdown()
+
+
+def start_session(
+    is_daemon: bool = True,
+    remote_url: Optional[str] = None,
+    tcp_port: Optional[int] = None,
+    unix_domain_socket: Optional[str] = None,
+    stop_event: threading.Event = None,
+    snowpark_session: Optional[snowpark.Session] = None,
+    connection_parameters: Optional[Dict[str, str]] = None,
+    max_grpc_message_size: int = None,
+    _add_signal_handler: bool = False,
+) -> threading.Thread | None:
+    """
+    Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.
+
+    Parameters:
+        is_daemon (bool): Should run the server as daemon or not. use True to automatically shut the Spark connect
+            server down when the main program (or test) finishes. use False to start the server in a
+            stand-alone, long-running mode.
+        remote_url (Optional[str]): sc:// URL on which to start the Spark Connect server. This option is incompatible with the tcp_port
+            and unix_domain_socket parameters.
+        tcp_port (Optional[int]): TCP port on which to start the Spark Connect server. This option is incompatible with
+            the remote_url and unix_domain_socket parameters.
+        unix_domain_socket (Optional[str]): Path to the unix domain socket on which to start the Spark Connect server.
+            This option is incompatible with the remote_url and tcp_port parameters.
+        stop_event (Optional[threading.Event]): Stop the SAS server when stop_event.set() is called.
+            Only works when is_daemon=True.
+        snowpark_session: A Snowpark session to use for this connection; currently the only applicable use of this is to
+            pass in the session created by the stored proc environment.
+        connection_parameters: A dictionary of connection parameters to use to create the Snowpark session. If this is
+            provided, the `snowpark_session` parameter must be None.
+    """
+
+    try:
+        # Set max grpc message size if provided
+        if max_grpc_message_size is not None:
+            set_grpc_max_message_size(max_grpc_message_size)
+
+        # Validate startup parameters
+        snowpark_session = validate_startup_parameters(
+            snowpark_session, connection_parameters
+        )
+
+        server_running = get_server_running()
+        if server_running.is_set():
+            url = get_client_url()
+            logger.warning(f"Snowpark Connect session is already running at {url}")
+            return
+
+        # Configure server URL
+        configure_server_url(remote_url, tcp_port, unix_domain_socket)
+
+        _disable_protobuf_recursion_limit()
+
+        if _add_signal_handler:
+            setup_signal_handlers(stop_event)
+
+        if is_daemon:
+            arguments = (stop_event, snowpark_session)
+            server_thread = threading.Thread(target=_serve, args=arguments, daemon=True)
+            server_thread.start()
+            server_running.wait()
+            if get_server_error():
+                exception = RuntimeError("Snowpark Connect session failed to start")
+                attach_custom_error_code(
+                    exception, ErrorCodes.STARTUP_CONNECTION_FAILED
+                )
+                raise exception
+
+            return server_thread
+        else:
+            # Launch in the foreground with stop_event
+            _serve(stop_event=stop_event, session=snowpark_session)
+    except Exception as e:
+        _reset_server_run_state()
+        logger.error(e, exc_info=True)
+        attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
+        raise e
+
+
+def init_spark_session(conf: SparkConf = None) -> SparkSession:
+    """
+    Initialize and return a Spark session.
+
+    Parameters:
+        conf (SparkConf): Optional Spark configuration.
+
+    Returns:
+        A new SparkSession connected to the Snowpark Connect thin Client server.
+    """
+    _setup_spark_environment(False)
+    from snowflake.snowpark_connect.client.utils.session import (
+        _get_current_snowpark_session,
+    )
+
+    snowpark_session = _get_current_snowpark_session()
+    start_session(snowpark_session=snowpark_session)
+    return get_session(conf=conf)
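
For orientation, below is a minimal usage sketch (not part of the diff) of the entry points added in snowflake/snowpark_connect/client/server.py. It assumes snowpark-connect 1.7.0 is installed and that a valid Snowflake connection is available to the current Snowpark session (for example, a 'spark-connect' entry in connections.toml); the DataFrame contents are placeholders.

from snowflake.snowpark_connect.client.server import init_spark_session

# init_spark_session() starts the local thin-client gRPC servicer via
# start_session() if it is not already running, then returns a SparkSession
# bound to it.
spark = init_spark_session()

# Ordinary PySpark calls are forwarded by SnowflakeConnectClientServicer to the
# Snowflake backend over REST (SparkConnectResource).
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
df.show()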