snowpark-connect 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (87)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +200 -102
  3. snowflake/snowpark_connect/column_qualifier.py +47 -0
  4. snowflake/snowpark_connect/config.py +51 -16
  5. snowflake/snowpark_connect/dataframe_container.py +3 -2
  6. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  7. snowflake/snowpark_connect/error/error_codes.py +50 -0
  8. snowflake/snowpark_connect/error/error_utils.py +142 -22
  9. snowflake/snowpark_connect/error/exceptions.py +13 -4
  10. snowflake/snowpark_connect/execute_plan/map_execution_command.py +9 -3
  11. snowflake/snowpark_connect/execute_plan/map_execution_root.py +5 -1
  12. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  13. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  14. snowflake/snowpark_connect/expression/literal.py +7 -1
  15. snowflake/snowpark_connect/expression/map_cast.py +17 -5
  16. snowflake/snowpark_connect/expression/map_expression.py +53 -8
  17. snowflake/snowpark_connect/expression/map_extension.py +37 -11
  18. snowflake/snowpark_connect/expression/map_sql_expression.py +102 -32
  19. snowflake/snowpark_connect/expression/map_udf.py +10 -2
  20. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +38 -14
  21. snowflake/snowpark_connect/expression/map_unresolved_function.py +1476 -292
  22. snowflake/snowpark_connect/expression/map_unresolved_star.py +14 -8
  23. snowflake/snowpark_connect/expression/map_update_fields.py +14 -4
  24. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  25. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  26. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +38 -13
  27. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  28. snowflake/snowpark_connect/relation/io_utils.py +6 -1
  29. snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
  30. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  31. snowflake/snowpark_connect/relation/map_column_ops.py +92 -59
  32. snowflake/snowpark_connect/relation/map_extension.py +38 -17
  33. snowflake/snowpark_connect/relation/map_join.py +26 -12
  34. snowflake/snowpark_connect/relation/map_local_relation.py +5 -1
  35. snowflake/snowpark_connect/relation/map_relation.py +33 -7
  36. snowflake/snowpark_connect/relation/map_row_ops.py +23 -7
  37. snowflake/snowpark_connect/relation/map_sql.py +124 -25
  38. snowflake/snowpark_connect/relation/map_stats.py +5 -1
  39. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  40. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  41. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +49 -13
  42. snowflake/snowpark_connect/relation/read/map_read.py +15 -3
  43. snowflake/snowpark_connect/relation/read/map_read_csv.py +11 -3
  44. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  45. snowflake/snowpark_connect/relation/read/map_read_json.py +8 -2
  46. snowflake/snowpark_connect/relation/read/map_read_parquet.py +13 -3
  47. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  48. snowflake/snowpark_connect/relation/read/map_read_table.py +21 -8
  49. snowflake/snowpark_connect/relation/read/map_read_text.py +5 -1
  50. snowflake/snowpark_connect/relation/read/metadata_utils.py +5 -1
  51. snowflake/snowpark_connect/relation/stage_locator.py +5 -1
  52. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  53. snowflake/snowpark_connect/relation/write/map_write.py +160 -48
  54. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  55. snowflake/snowpark_connect/resources_initializer.py +5 -1
  56. snowflake/snowpark_connect/server.py +73 -21
  57. snowflake/snowpark_connect/type_mapping.py +90 -20
  58. snowflake/snowpark_connect/typed_column.py +8 -6
  59. snowflake/snowpark_connect/utils/context.py +42 -1
  60. snowflake/snowpark_connect/utils/describe_query_cache.py +3 -0
  61. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  62. snowflake/snowpark_connect/utils/identifiers.py +11 -3
  63. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  64. snowflake/snowpark_connect/utils/profiling.py +25 -8
  65. snowflake/snowpark_connect/utils/scala_udf_utils.py +11 -3
  66. snowflake/snowpark_connect/utils/session.py +24 -4
  67. snowflake/snowpark_connect/utils/telemetry.py +6 -0
  68. snowflake/snowpark_connect/utils/temporary_view_cache.py +5 -1
  69. snowflake/snowpark_connect/utils/udf_cache.py +5 -3
  70. snowflake/snowpark_connect/utils/udf_helper.py +20 -6
  71. snowflake/snowpark_connect/utils/udf_utils.py +4 -4
  72. snowflake/snowpark_connect/utils/udtf_helper.py +5 -1
  73. snowflake/snowpark_connect/utils/udtf_utils.py +34 -26
  74. snowflake/snowpark_connect/version.py +1 -1
  75. snowflake/snowpark_decoder/dp_session.py +1 -1
  76. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +7 -3
  77. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +85 -85
  78. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  79. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  80. {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
  81. {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
  82. {snowpark_connect-0.30.1.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
  83. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
  84. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
  85. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
  86. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
  87. {snowpark_connect-0.30.1.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/error/error_utils.py
@@ -12,6 +12,7 @@ https://github.com/apache/spark/blob/master/common/utils/src/main/resources/erro
  import json
  import pathlib
  import re
+ import threading
  import traceback

  import jpype
@@ -35,9 +36,12 @@ from snowflake.core.exceptions import NotFoundError

  from snowflake.connector.errors import ProgrammingError
  from snowflake.snowpark.exceptions import SnowparkClientException, SnowparkSQLException
- from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
  from snowflake.snowpark_connect.error.error_mapping import ERROR_MAPPINGS_JSON

+ # Thread-local storage for custom error codes when we can't attach them directly to exceptions
+ _thread_local = threading.local()
+
  # The JSON string in error_mapping.py is a copy of https://github.com/apache/spark/blob/master/common/utils/src/main/resources/error/error-conditions.json.
  # The file doesn't have to be synced with spark latest main. Just update it when required.
  current_dir = pathlib.Path(__file__).parent.resolve()
@@ -81,6 +85,21 @@ invalid_bit_pattern = re.compile(
  )


+ def attach_custom_error_code(exception: Exception, custom_error_code: int) -> Exception:
+     """
+     Attach a custom error code to any exception instance.
+     This allows us to add custom error codes to existing PySpark exceptions.
+     """
+     if not hasattr(exception, "custom_error_code"):
+         try:
+             exception.custom_error_code = custom_error_code
+         except (AttributeError, TypeError):
+             # Some exception types (like Java exceptions) don't allow setting custom attributes
+             # Store the error code in thread-local storage for later retrieval
+             _thread_local.pending_error_code = custom_error_code
+     return exception
+
+
  def contains_udtf_select(sql_string):
      # This function tries to detect if the SQL string contains a UDTF (User Defined Table Function) call.
      # Looks for select FROM TABLE(...) or FROM ( TABLE(...) )
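
A minimal usage sketch of the new helper (not part of the diff; the try/except framing is illustrative, the helper and the ErrorCodes constant are taken from the changes above). Ordinary Python exceptions get the attribute directly; exception types that reject new attributes, such as wrapped Java exceptions, fall back to the thread-local slot that build_grpc_error_response later drains:

    from snowflake.snowpark_connect.error.error_codes import ErrorCodes
    from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code

    try:
        raise ValueError("unsupported literal type")
    except ValueError as ex:
        # Ordinary Python exceptions accept the new attribute in place.
        attach_custom_error_code(ex, ErrorCodes.UNSUPPORTED_TYPE)
        assert ex.custom_error_code == ErrorCodes.UNSUPPORTED_TYPE
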
@@ -100,20 +119,29 @@ def _get_converted_known_sql_or_custom_exception(

      # custom exception
      if "[snowpark_connect::invalid_array_index]" in msg:
-         return ArrayIndexOutOfBoundsException(
+         exception = ArrayIndexOutOfBoundsException(
              message='The index <indexValue> is out of bounds. The array has <arraySize> elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
          )
+         attach_custom_error_code(exception, ErrorCodes.ARRAY_INDEX_OUT_OF_BOUNDS)
+         return exception
      if "[snowpark_connect::invalid_index_of_zero]" in msg:
-         return SparkRuntimeException(
+         exception = SparkRuntimeException(
              message="[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         return exception
      if "[snowpark_connect::invalid_index_of_zero_in_slice]" in msg:
-         return SparkRuntimeException(
+         exception = SparkRuntimeException(
              message="Unexpected value for start in function slice: SQL array indices start at 1."
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         return exception
+
      invalid_bit = invalid_bit_pattern.search(msg)
      if invalid_bit:
-         return IllegalArgumentException(message=invalid_bit.group(0))
+         exception = IllegalArgumentException(message=invalid_bit.group(0))
+         attach_custom_error_code(exception, ErrorCodes.INVALID_FUNCTION_ARGUMENT)
+         return exception
      match = snowpark_connect_exception_pattern.search(
          ex.message if hasattr(ex, "message") else str(ex)
      )
@@ -125,71 +153,136 @@ def _get_converted_known_sql_or_custom_exception(
              if class_name
              else SparkConnectGrpcException
          )
-         return exception_class(message=message)
+         exception = exception_class(message=message)
+         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+         return exception

      if "select with no columns" in msg and contains_udtf_select(query):
          # We try our best to detect if the SQL string contains a UDTF call and the output schema is empty.
-         return PythonException(message=f"[UDTF_RETURN_SCHEMA_MISMATCH] {ex.message}")
+         exception = PythonException(
+             message=f"[UDTF_RETURN_SCHEMA_MISMATCH] {ex.message}"
+         )
+         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+         return exception

      # known sql exception
      if ex.sql_error_code not in (100038, 100037, 100035, 100357):
          return None

      if "(22018): numeric value" in msg:
-         return NumberFormatException(
+         exception = NumberFormatException(
              message='[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error.'
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+         return exception
      if "(22018): boolean value" in msg:
-         return SparkRuntimeException(
+         exception = SparkRuntimeException(
              message='[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error.'
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+         return exception
      if "(22007): timestamp" in msg:
-         return AnalysisException(
+         exception = AnalysisException(
              "[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Data type mismatch"
          )
+         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+         return exception

      if getattr(ex, "sql_error_code", None) == 100357:
          if re.search(init_multi_args_exception_pattern, msg):
-             return PythonException(
+             exception = PythonException(
                  message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the init method {ex.message}"
              )
+             attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+             return exception
          if re.search(terminate_multi_args_exception_pattern, msg):
-             return PythonException(
+             exception = PythonException(
                  message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the terminate method: {ex.message}"
              )
+             attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+             return exception

      if "failed to split string, provided pattern:" in msg:
-         return IllegalArgumentException(
+         exception = IllegalArgumentException(
              message=f"Failed to split string using provided pattern. {ex.message}"
          )
+         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+         return exception

      if "100357" in msg and "wrong tuple size for returned value" in msg:
-         return PythonException(
+         exception = PythonException(
              message=f"[UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not match the specified schema. {ex.message}"
          )
+         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+         return exception

      if "100357 (p0000): python interpreter error:" in msg:
          if "in eval" in msg:
-             return PythonException(
+             exception = PythonException(
                  message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' method: error. {ex.message}"
              )
+             attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+             return exception

          if "in terminate" in msg:
-             return PythonException(
+             exception = PythonException(
                  message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the 'terminate' method: terminate error. {ex.message}"
              )
+             attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+             return exception

          if "object is not iterable" in msg and contains_udtf_select(query):
-             return PythonException(
+             exception = PythonException(
                  message=f"[UDTF_RETURN_NOT_ITERABLE] {ex.message}"
              )
+             attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+             return exception

-         return PythonException(message=f"{ex.message}")
+         exception = PythonException(message=f"{ex.message}")
+         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+         return exception

      return None


+ def _sanitize_custom_error_message(msg):
+     if "[snowpark_connect::unsupported_operation]" in msg:
+         return (
+             msg.replace("[snowpark_connect::unsupported_operation] ", ""),
+             ErrorCodes.UNSUPPORTED_OPERATION,
+         )
+     if "[snowpark_connect::internal_error]" in msg:
+         return (
+             msg.replace("[snowpark_connect::internal_error] ", ""),
+             ErrorCodes.INTERNAL_ERROR,
+         )
+     if "[snowpark_connect::invalid_operation]" in msg:
+         return (
+             msg.replace("[snowpark_connect::invalid_operation] ", ""),
+             ErrorCodes.INVALID_OPERATION,
+         )
+     if "[snowpark_connect::type_mismatch]" in msg:
+         return (
+             msg.replace("[snowpark_connect::type_mismatch] ", ""),
+             ErrorCodes.TYPE_MISMATCH,
+         )
+     if "[snowpark_connect::invalid_input]" in msg:
+         return (
+             msg.replace("[snowpark_connect::invalid_input] ", ""),
+             ErrorCodes.INVALID_INPUT,
+         )
+     if "[snowpark_connect::unsupported_type]" in msg:
+         return (
+             msg.replace("[snowpark_connect::unsupported_type] ", ""),
+             ErrorCodes.UNSUPPORTED_TYPE,
+         )
+     return msg, None
+
+
  def build_grpc_error_response(ex: Exception) -> status_pb2.Status:
+     # Lazy import to avoid circular dependency
+     from snowflake.snowpark_connect.config import global_config
+
      include_stack_trace = (
          global_config.get("spark.sql.pyspark.jvmStacktrace.enabled")
          if hasattr(global_config, "spark.sql.pyspark.jvmStacktrace.enabled")
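
A hedged illustration of what the new _sanitize_custom_error_message helper does (the sample message text is made up; the tag and constant come from the code above): it strips the internal [snowpark_connect::...] marker from the message and reports the matching ErrorCodes constant, or returns the message unchanged with None when no marker is present.

    clean, code = _sanitize_custom_error_message(
        "[snowpark_connect::unsupported_operation] This operation is not supported"
    )
    # clean == "This operation is not supported"
    # code == ErrorCodes.UNSUPPORTED_OPERATION
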
@@ -211,6 +304,7 @@ def build_grpc_error_response(ex: Exception) -> status_pb2.Status:
              error_class="DIVIDE_BY_ZERO",
              message_parameters={"config": '"spark.sql.ansi.enabled"'},
          )
+         attach_custom_error_code(ex, ErrorCodes.DIVISION_BY_ZERO)
      elif ex.sql_error_code in (100096, 100040):
          # Spark seems to want the Java base class instead of org.apache.spark.sql.SparkDateTimeException
          # which is what should really be thrown
@@ -299,14 +393,40 @@ def build_grpc_error_response(ex: Exception) -> status_pb2.Status:
          domain="snowflake.sas",
      )

-     detail = any_pb2.Any()
-     detail.Pack(error_info)
-
      if message is None:
          message = str(ex)

+     custom_error_code = None
+
+     # attach error code via the exception message
+     message, custom_error_code_from_msg = _sanitize_custom_error_message(message)
+
+     # Check if exception already has a custom error code, if not add INTERNAL_ERROR as default
+     if not hasattr(ex, "custom_error_code") or ex.custom_error_code is None:
+         attach_custom_error_code(
+             ex,
+             ErrorCodes.INTERNAL_ERROR
+             if custom_error_code_from_msg is None
+             else custom_error_code_from_msg,
+         )
+
+     # Get the custom error code from the exception or thread-local storage
+     custom_error_code = getattr(ex, "custom_error_code", None) or getattr(
+         _thread_local, "pending_error_code", None
+     )
+
+     # Clear thread-local storage after retrieving the error code
+     if hasattr(_thread_local, "pending_error_code"):
+         delattr(_thread_local, "pending_error_code")
+
+     separator = "==========================================="
+     error_code_added_message = f"\n{separator}\nSNOWPARK CONNECT ERROR CODE: {custom_error_code}\n{separator}\n{message}"
+
+     detail = any_pb2.Any()
+     detail.Pack(error_info)
+
      rich_status = status_pb2.Status(
-         code=code_pb2.INTERNAL, message=message, details=[detail]
+         code=code_pb2.INTERNAL, message=error_code_added_message, details=[detail]
      )
      return rich_status
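
With these changes every gRPC error surfaced to the client leads with a SNOWPARK CONNECT ERROR CODE banner, followed by the (sanitized) original message. Roughly, the message field of the returned status looks like this (the code value is shown as a placeholder, and the sample message is illustrative):

    ===========================================
    SNOWPARK CONNECT ERROR CODE: <code from ErrorCodes>
    ===========================================
    [CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. ...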

snowflake/snowpark_connect/error/exceptions.py
@@ -2,27 +2,36 @@
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #

+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+

  class SnowparkConnectException(Exception):
      """Parent class to all SnowparkConnect related exceptions."""

-     def __init__(self, *args, **kwargs) -> None:
+     def __init__(self, *args, custom_error_code=None, **kwargs) -> None:
          super().__init__(*args, **kwargs)
+         self.custom_error_code = custom_error_code


  class MissingDatabase(SnowparkConnectException):
-     def __init__(self) -> None:
+     def __init__(self, custom_error_code=None) -> None:
          super().__init__(
              "No default database found in session",
+             custom_error_code=custom_error_code or ErrorCodes.MISSING_DATABASE,
          )


  class MissingSchema(SnowparkConnectException):
-     def __init__(self) -> None:
+     def __init__(self, custom_error_code=None) -> None:
          super().__init__(
              "No default schema found in session",
+             custom_error_code=custom_error_code or ErrorCodes.MISSING_SCHEMA,
          )


  class MaxRetryExceeded(SnowparkConnectException):
-     ...
+     def __init__(
+         self,
+         message="Maximum retry attempts exceeded",
+     ) -> None:
+         super().__init__(message)
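
A quick sketch of how the reworked hierarchy carries the code (class and constant names are from the diff; the try/except framing is illustrative):

    from snowflake.snowpark_connect.error.error_codes import ErrorCodes
    from snowflake.snowpark_connect.error.exceptions import MissingDatabase

    try:
        raise MissingDatabase()
    except MissingDatabase as ex:
        # custom_error_code defaults to MISSING_DATABASE unless the caller passes one.
        assert ex.custom_error_code == ErrorCodes.MISSING_DATABASE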

snowflake/snowpark_connect/execute_plan/map_execution_command.py
@@ -11,6 +11,8 @@ from snowflake.snowpark_connect.column_name_handler import ColumnNames
  from snowflake.snowpark_connect.config import global_config, sessions_config
  from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.execute_plan.utils import pandas_to_arrow_batches_bytes
  from snowflake.snowpark_connect.expression import map_udf
  from snowflake.snowpark_connect.relation import map_udtf
@@ -52,9 +54,11 @@ def _create_column_rename_map(
              new_column_name = (
                  f"{new_column_name}_DEDUP_{column_counts[normalized_name] - 1}"
              )
-             renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, []))
+             renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, set()))
          else:
-             not_renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, []))
+             not_renamed_cols.append(
+                 ColumnNames(new_column_name, col.snowpark_name, set())
+             )

      if len(renamed_cols) == 0:
          return {
@@ -207,6 +211,8 @@ def map_execution_command(
              map_udtf.register_udtf(request.plan.command.register_table_function)

          case other:
-             raise SnowparkConnectNotImplementedError(
+             exception = SnowparkConnectNotImplementedError(
                  f"Command type {other} not implemented"
              )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception

snowflake/snowpark_connect/execute_plan/map_execution_root.py
@@ -21,6 +21,8 @@ from snowflake.snowpark._internal.utils import (
  )
  from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.execute_plan.utils import (
      arrow_table_to_arrow_bytes,
      pandas_to_arrow_batches_bytes,
@@ -56,7 +58,9 @@ def sproc_connector_fetch_arrow_batches_fix(self) -> Iterator[Table]:
      if self._prefetch_hook is not None:
          self._prefetch_hook()
      if self._query_result_format != "arrow":
-         raise NotSupportedError
+         exception = NotSupportedError()
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception
      return self._result_set._fetch_arrow_batches()

snowflake/snowpark_connect/execute_plan/utils.py
@@ -8,6 +8,8 @@ import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  from pyspark.sql.pandas.types import _dedup_names

  from snowflake.snowpark import types as sf_types
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.type_mapping import map_snowpark_types_to_pyarrow_types
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
@@ -88,9 +90,11 @@ def is_streaming(rel: relation_proto.Relation) -> bool:
          case "html_string":
              return is_streaming(rel.html_string.input)
          case "cached_remote_relation":
-             raise SnowparkConnectNotImplementedError(
+             exception = SnowparkConnectNotImplementedError(
                  "Cached remote relation not implemented"
              )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          case "common_inline_user_defined_table_function":
              return is_streaming(rel.common_inline_user_defined_table_function.input)
          case "fill_na":

snowflake/snowpark_connect/expression/function_defaults.py
@@ -7,6 +7,9 @@ from typing import Any
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2
  import pyspark.sql.connect.proto.types_pb2 as types_pb2

+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+

  @dataclass(frozen=True)
  class DefaultParameter:
@@ -154,7 +157,9 @@ def _create_literal_expression(value: Any) -> expressions_pb2.Expression:
          null_type.null.SetInParent()
          expr.literal.null.CopyFrom(null_type)
      else:
-         raise ValueError(f"Unsupported literal type: {value}")
+         exception = ValueError(f"Unsupported literal type: {value}")
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
+         raise exception

      return expr

@@ -189,11 +194,13 @@ def inject_function_defaults(

      # Check if any required params are missing.
      if missing_arg_count > len(defaults):
-         raise ValueError(
+         exception = ValueError(
              f"Function '{function_name}' is missing required arguments. "
              f"Expected {total_args} args, got {current_arg_count}, "
              f"but only {len(defaults)} defaults are defined."
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_FUNCTION_ARGUMENT)
+         raise exception

      defaults_to_append = defaults[-missing_arg_count:]
      injected = False

snowflake/snowpark_connect/expression/literal.py
@@ -10,6 +10,8 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  from tzlocal import get_localzone

  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
@@ -100,4 +102,8 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
          case "null" | None:
              return None, "NULL"
          case other:
-             raise SnowparkConnectNotImplementedError(f"Other Literal Type {other}")
+             exception = SnowparkConnectNotImplementedError(
+                 f"Other Literal Type {other}"
+             )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception

snowflake/snowpark_connect/expression/map_cast.py
@@ -31,6 +31,8 @@ from snowflake.snowpark.types import (
  )
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.type_mapping import (
      map_type_string_to_snowpark_type,
@@ -87,7 +89,9 @@ def map_cast(
              to_type = map_type_string_to_snowpark_type(exp.cast.type_str)
              to_type_str = exp.cast.type_str.upper()
          case _:
-             raise ValueError("No type to cast to")
+             exception = ValueError("No type to cast to")
+             attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+             raise exception

      from_exp = exp.cast.expr
      new_name, typed_column = map_single_column_expression(
@@ -300,9 +304,11 @@ def map_cast(
              else:
                  result_exp = snowpark_fn.try_cast(col, to_type)
          case (StringType(), _):
-             raise AnalysisException(
+             exception = AnalysisException(
                  f"""[DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION] Cannot resolve "{col_name}" due to data type mismatch: cannot cast "{snowpark_to_proto_type(from_type, column_mapping)}" to "{exp.cast.type_str.upper()}".;"""
              )
+             attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+             raise exception
          case _:
              result_exp = snowpark_fn.cast(col, to_type)

@@ -317,9 +323,11 @@ def sanity_check(
      """

      if isinstance(from_type, LongType) and isinstance(to_type, BinaryType):
-         raise NumberFormatException(
+         exception = NumberFormatException(
              f"""[DATATYPE_MISMATCH.CAST_WITH_CONF_SUGGESTION] Cannot resolve "CAST({value} AS BINARY)" due to data type mismatch: cannot cast "BIGINT" to "BINARY" with ANSI mode on."""
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+         raise exception

      if (
          from_type_cast
@@ -329,9 +337,11 @@ def sanity_check(
          if value is not None:
              value = value.strip().lower()
          if value not in {"t", "true", "f", "false", "y", "yes", "n", "no", "0", "1"}:
-             raise SparkRuntimeException(
+             exception = SparkRuntimeException(
                  f"""[CAST_INVALID_INPUT] The value '{value}' of the type "STRING" cannot be cast to "BOOLEAN" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error."""
              )
+             attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+             raise exception

      raise_cast_failure_exception = False
      if isinstance(to_type, _IntegralType):
@@ -351,6 +361,8 @@ def sanity_check(
          except Exception:
              raise_cast_failure_exception = True
      if raise_cast_failure_exception:
-         raise NumberFormatException(
+         exception = NumberFormatException(
              """[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error."""
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+         raise exception

snowflake/snowpark_connect/expression/map_expression.py
@@ -14,6 +14,8 @@ from snowflake.snowpark import Session
  from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
  from snowflake.snowpark.types import TimestampTimeZone, TimestampType
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.expression import (
      map_extension,
      map_udf,
@@ -62,9 +64,11 @@ def map_alias(
          # Multi-column case: handle like explode("map").alias("key", "value")
          col_names, col = map_expression(alias.expr, column_mapping, typer)
          if len(col_names) != len(list(alias.name)):
-             raise ValueError(
+             exception = ValueError(
                  f"Found the unresolved operator: 'Project [{col_names} AS ({', '.join(list(alias.name))})]. Number of aliases ({len(list(alias.name))}) does not match number of columns ({len(col_names)})"
              )
+             attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+             raise exception
          return list(alias.name), col

      name, col = map_single_column_expression(alias.expr, column_mapping, typer)
@@ -226,22 +230,27 @@ def map_expression(
                      | exp.sort_order.SORT_DIRECTION_ASCENDING
                  ):
                      if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_LAST:
-                         return [child_name], snowpark_fn.asc_nulls_last(child_column)
+                         col = snowpark_fn.asc_nulls_last(child_column.col)
                      else:
                          # If nulls are not specified or null_ordering is FIRST in the sort order, Spark defaults to nulls
                          # first in the case of ascending sort order.
-                         return [child_name], snowpark_fn.asc_nulls_first(child_column)
+                         col = snowpark_fn.asc_nulls_first(child_column.col)
                  case exp.sort_order.SORT_DIRECTION_DESCENDING:
                      if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_FIRST:
-                         return [child_name], snowpark_fn.desc_nulls_first(child_column)
+                         col = snowpark_fn.desc_nulls_first(child_column.col)
                      else:
                          # If nulls are not specified or null_ordering is LAST in the sort order, Spark defaults to nulls
                          # last in the case of descending sort order.
-                         return [child_name], snowpark_fn.desc_nulls_last(child_column)
+                         col = snowpark_fn.desc_nulls_last(child_column.col)
                  case _:
-                     raise ValueError(
+                     exception = ValueError(
                          f"Invalid sort direction {exp.sort_order.direction}"
                      )
+                     attach_custom_error_code(
+                         exception, ErrorCodes.INVALID_FUNCTION_ARGUMENT
+                     )
+                     raise exception
+             return [child_name], TypedColumn(col, lambda: typer.type(col))
          case "unresolved_attribute":
              col_name, col = map_att.map_unresolved_attribute(exp, column_mapping, typer)
              # Check if this is a multi-column regex expansion
@@ -275,6 +284,36 @@ def map_expression(
              )
              return [col_name], col
          case "unresolved_function":
+             from snowflake.snowpark_connect.utils.context import (
+                 get_is_processing_order_by,
+             )
+
+             is_order_by = get_is_processing_order_by()
+             if is_order_by:
+                 # For expressions in an order by clause check if we can reuse already-computed column.
+                 if exp.unresolved_function.function_name:
+                     func_name = exp.unresolved_function.function_name
+                     available_columns = column_mapping.get_spark_columns()
+
+                     for col_name in available_columns:
+                         if (
+                             func_name.lower() in col_name.lower()
+                             and "(" in col_name
+                             and ")" in col_name
+                         ):
+                             # This looks like it might be an expression
+                             snowpark_col_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
+                                 col_name
+                             )
+                             if snowpark_col_name:
+                                 # Optimization applied - reusing already computed column
+                                 return [col_name], TypedColumn(
+                                     snowpark_fn.col(snowpark_col_name),
+                                     lambda col_name=snowpark_col_name: typer.type(
+                                         col_name
+                                     ),
+                                 )
+
              return map_func.map_unresolved_function(exp, column_mapping, typer)
          case "unresolved_named_lambda_variable":
              # Validate that this lambda variable is in scope
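
In PySpark terms, the new order-by branch lets a sort key that repeats an already-projected expression resolve to the existing output column instead of being recomputed. A hedged illustration (df is assumed to be an existing DataFrame; the column names are made up):

    from pyspark.sql import functions as F

    # The orderBy expression F.sum("amount") matches the projected "sum(amount)"
    # column from the aggregation, so the mapper can reuse that column.
    df.groupBy("region").agg(F.sum("amount")).orderBy(F.sum("amount"))
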
@@ -293,13 +332,17 @@ def map_expression(
                          col, lambda: typer.type(col)
                      )
                  else:
-                     raise AnalysisException(
+                     exception = AnalysisException(
                          f"Cannot resolve variable '{var_name}' within lambda function. "
                          f"Lambda functions can access their own parameters and parent dataframe columns. "
                          f"Current lambda parameters: {current_params}. "
                          f"If '{var_name}' is an outer scope lambda variable from a nested lambda, "
                          f"that is an unsupported feature in Snowflake SQL."
                      )
+                     attach_custom_error_code(
+                         exception, ErrorCodes.UNSUPPORTED_OPERATION
+                     )
+                     raise exception

              col = snowpark_fn.Column(
                  UnresolvedAttribute(exp.unresolved_named_lambda_variable.name_parts[0])
@@ -334,6 +377,8 @@ def map_expression(
          case "update_fields":
              return map_update_fields.map_update_fields(exp, column_mapping, typer)
          case _:
-             raise SnowparkConnectNotImplementedError(
+             exception = SnowparkConnectNotImplementedError(
                  f"Unsupported expression type {expr_type}"
              )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception