snowpark-connect 0.30.0__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (81)
  1. snowflake/snowpark_connect/column_name_handler.py +150 -25
  2. snowflake/snowpark_connect/config.py +54 -16
  3. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  4. snowflake/snowpark_connect/error/error_codes.py +50 -0
  5. snowflake/snowpark_connect/error/error_utils.py +142 -22
  6. snowflake/snowpark_connect/error/exceptions.py +13 -4
  7. snowflake/snowpark_connect/execute_plan/map_execution_command.py +5 -1
  8. snowflake/snowpark_connect/execute_plan/map_execution_root.py +5 -1
  9. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  10. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  11. snowflake/snowpark_connect/expression/literal.py +7 -1
  12. snowflake/snowpark_connect/expression/map_cast.py +17 -5
  13. snowflake/snowpark_connect/expression/map_expression.py +48 -4
  14. snowflake/snowpark_connect/expression/map_extension.py +25 -5
  15. snowflake/snowpark_connect/expression/map_sql_expression.py +65 -30
  16. snowflake/snowpark_connect/expression/map_udf.py +10 -2
  17. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +33 -9
  18. snowflake/snowpark_connect/expression/map_unresolved_function.py +627 -205
  19. snowflake/snowpark_connect/expression/map_unresolved_star.py +5 -1
  20. snowflake/snowpark_connect/expression/map_update_fields.py +14 -4
  21. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  22. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  23. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  24. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  25. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +34 -12
  26. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  27. snowflake/snowpark_connect/relation/io_utils.py +66 -4
  28. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  29. snowflake/snowpark_connect/relation/map_column_ops.py +88 -56
  30. snowflake/snowpark_connect/relation/map_extension.py +28 -8
  31. snowflake/snowpark_connect/relation/map_join.py +21 -10
  32. snowflake/snowpark_connect/relation/map_local_relation.py +5 -1
  33. snowflake/snowpark_connect/relation/map_relation.py +33 -7
  34. snowflake/snowpark_connect/relation/map_row_ops.py +36 -9
  35. snowflake/snowpark_connect/relation/map_sql.py +91 -24
  36. snowflake/snowpark_connect/relation/map_stats.py +25 -6
  37. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  38. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +49 -13
  39. snowflake/snowpark_connect/relation/read/map_read.py +24 -3
  40. snowflake/snowpark_connect/relation/read/map_read_csv.py +11 -3
  41. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  42. snowflake/snowpark_connect/relation/read/map_read_json.py +8 -2
  43. snowflake/snowpark_connect/relation/read/map_read_parquet.py +13 -3
  44. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  45. snowflake/snowpark_connect/relation/read/map_read_table.py +15 -5
  46. snowflake/snowpark_connect/relation/read/map_read_text.py +5 -1
  47. snowflake/snowpark_connect/relation/read/metadata_utils.py +5 -1
  48. snowflake/snowpark_connect/relation/stage_locator.py +5 -1
  49. snowflake/snowpark_connect/relation/utils.py +19 -2
  50. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  51. snowflake/snowpark_connect/relation/write/map_write.py +146 -63
  52. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  53. snowflake/snowpark_connect/resources_initializer.py +5 -1
  54. snowflake/snowpark_connect/server.py +72 -19
  55. snowflake/snowpark_connect/type_mapping.py +54 -17
  56. snowflake/snowpark_connect/utils/context.py +42 -1
  57. snowflake/snowpark_connect/utils/describe_query_cache.py +3 -0
  58. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  59. snowflake/snowpark_connect/utils/identifiers.py +11 -3
  60. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  61. snowflake/snowpark_connect/utils/profiling.py +25 -8
  62. snowflake/snowpark_connect/utils/scala_udf_utils.py +11 -3
  63. snowflake/snowpark_connect/utils/session.py +5 -2
  64. snowflake/snowpark_connect/utils/telemetry.py +81 -18
  65. snowflake/snowpark_connect/utils/temporary_view_cache.py +5 -1
  66. snowflake/snowpark_connect/utils/udf_cache.py +5 -3
  67. snowflake/snowpark_connect/utils/udf_helper.py +20 -6
  68. snowflake/snowpark_connect/utils/udf_utils.py +4 -4
  69. snowflake/snowpark_connect/utils/udtf_helper.py +5 -1
  70. snowflake/snowpark_connect/utils/udtf_utils.py +34 -26
  71. snowflake/snowpark_connect/version.py +1 -1
  72. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/METADATA +3 -2
  73. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/RECORD +81 -78
  74. {snowpark_connect-0.30.0.data → snowpark_connect-0.31.0.data}/scripts/snowpark-connect +0 -0
  75. {snowpark_connect-0.30.0.data → snowpark_connect-0.31.0.data}/scripts/snowpark-session +0 -0
  76. {snowpark_connect-0.30.0.data → snowpark_connect-0.31.0.data}/scripts/snowpark-submit +0 -0
  77. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/WHEEL +0 -0
  78. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/licenses/LICENSE-binary +0 -0
  79. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/licenses/LICENSE.txt +0 -0
  80. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/licenses/NOTICE-binary +0 -0
  81. {snowpark_connect-0.30.0.dist-info → snowpark_connect-0.31.0.dist-info}/top_level.txt +0 -0
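
The hunks below share one recurring change: instead of raising directly, call sites now build the exception, tag it with a code from the new error_codes module via attach_custom_error_code, and then raise it. The sketch below illustrates only that call-site pattern; the ErrorCodes members shown are taken from the hunks, while the body of attach_custom_error_code is an assumption (the diff shows the calls, not the helper's implementation).

    from enum import Enum

    class ErrorCodes(Enum):
        # Two of the members observed at call sites in this diff (not the full enum).
        INVALID_INPUT = "INVALID_INPUT"
        UNSUPPORTED_OPERATION = "UNSUPPORTED_OPERATION"

    def attach_custom_error_code(exception: Exception, code: ErrorCodes) -> None:
        # Hypothetical implementation: stash the code on the exception instance so a
        # top-level handler can surface it in error responses or telemetry.
        exception.custom_error_code = code  # type: ignore[attr-defined]

    def check_sample_fraction(frac: float) -> None:
        # Mirrors the call-site pattern applied throughout 0.31.0.
        if frac < 0 or frac > 1:
            exception = ValueError("Sample fraction must be between 0 and 1")
            attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
            raise exception
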
@@ -33,7 +33,11 @@ from snowflake.snowpark_connect.column_name_handler import (
  )
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
- from snowflake.snowpark_connect.error.error_utils import SparkException
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import (
+     SparkException,
+     attach_custom_error_code,
+ )
  from snowflake.snowpark_connect.expression.map_expression import (
      map_alias,
      map_expression,
@@ -369,56 +373,64 @@ def map_sort(
          for col in input_container.column_map.get_spark_columns()
      ]

-     for so in sort_order:
-         if so.child.HasField("literal"):
-             column_index = unwrap_literal(so.child)
-             try:
-                 if column_index <= 0:
-                     raise IndexError
-                 col = input_df[column_index - 1]
-             except IndexError:
-                 raise AnalysisException(
-                     f"""[ORDER_BY_POS_OUT_OF_RANGE] ORDER BY position {column_index} is not in select list (valid range is [1, {len(input_df.columns)})])."""
+     # Process ORDER BY expressions with a context flag to enable column reuse optimization
+     from snowflake.snowpark_connect.utils.context import push_processing_order_by_scope
+
+     with push_processing_order_by_scope():
+         for so in sort_order:
+             if so.child.HasField("literal"):
+                 column_index = unwrap_literal(so.child)
+                 try:
+                     if column_index <= 0:
+                         exception = IndexError()
+                         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                         raise exception
+                     col = input_df[column_index - 1]
+                 except IndexError:
+                     exception = AnalysisException(
+                         f"""[ORDER_BY_POS_OUT_OF_RANGE] ORDER BY position {column_index} is not in select list (valid range is [1, {len(input_df.columns)})])."""
+                     )
+                     attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                     raise exception
+             else:
+                 _, typed_column = map_single_column_expression(
+                     so.child, input_container.column_map, typer
                  )
-         else:
-             _, typed_column = map_single_column_expression(
-                 so.child, input_container.column_map, typer
-             )
-             col = typed_column.col
+                 col = typed_column.col

-         match (so.direction, so.null_ordering):
-             case (
-                 expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING,
-                 expressions_proto.Expression.SortOrder.SORT_NULLS_FIRST,
-             ):
-                 col = col.asc_nulls_first()
-             case (
-                 expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING,
-                 expressions_proto.Expression.SortOrder.SORT_NULLS_LAST,
-             ):
-                 col = col.asc_nulls_last()
-             case (
-                 expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING,
-                 expressions_proto.Expression.SortOrder.SORT_NULLS_FIRST,
-             ):
-                 col = col.desc_nulls_first()
-             case (
-                 expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING,
-                 expressions_proto.Expression.SortOrder.SORT_NULLS_LAST,
-             ):
-                 col = col.desc_nulls_last()
+             match (so.direction, so.null_ordering):
+                 case (
+                     expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING,
+                     expressions_proto.Expression.SortOrder.SORT_NULLS_FIRST,
+                 ):
+                     col = col.asc_nulls_first()
+                 case (
+                     expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING,
+                     expressions_proto.Expression.SortOrder.SORT_NULLS_LAST,
+                 ):
+                     col = col.asc_nulls_last()
+                 case (
+                     expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING,
+                     expressions_proto.Expression.SortOrder.SORT_NULLS_FIRST,
+                 ):
+                     col = col.desc_nulls_first()
+                 case (
+                     expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING,
+                     expressions_proto.Expression.SortOrder.SORT_NULLS_LAST,
+                 ):
+                     col = col.desc_nulls_last()

-         cols.append(col)
+             cols.append(col)

-         ascending.append(
-             so.direction
-             == expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING
-         )
-         if (
-             so.direction
-             != expressions_proto.Expression.SortOrder.SORT_DIRECTION_UNSPECIFIED
-         ):
-             order_specified = True
+             ascending.append(
+                 so.direction
+                 == expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING
+             )
+             if (
+                 so.direction
+                 != expressions_proto.Expression.SortOrder.SORT_DIRECTION_UNSPECIFIED
+             ):
+                 order_specified = True

      # TODO: sort.isglobal.
      if not order_specified:
@@ -446,9 +458,11 @@ def map_to_df(
      new_column_names = list(rel.to_df.column_names)
      if len(new_column_names) != len(input_container.column_map.columns):
          # TODO: Check error type here
-         raise ValueError(
+         exception = ValueError(
              "Number of column names must match number of columns in DataFrame"
          )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+         raise exception
      snowpark_new_column_names = make_column_names_snowpark_compatible(
          new_column_names, rel.common.plan_id
      )
@@ -507,9 +521,11 @@ def map_to_schema(
      for field in rel.to_schema.schema.struct.fields:
          if field.name in already_existing_columns:
              if count_case_insensitive_column_names[field.name.lower()] > 1:
-                 raise AnalysisException(
+                 exception = AnalysisException(
                      f"[AMBIGUOUS_COLUMN_OR_FIELD] Column or field `{field.name}` is ambiguous and has {len(input_container.column_map.spark_to_col[field.name])} matches."
                  )
+                 attach_custom_error_code(exception, ErrorCodes.AMBIGUOUS_COLUMN_NAME)
+                 raise exception
              snowpark_name = None
              for name in input_container.column_map.spark_to_col:
                  if name.lower() == field.name.lower():
@@ -526,17 +542,23 @@ def map_to_schema(
                  and snowpark_field.nullable
                  and not isinstance(snowpark_field.datatype, StructType)
              ):
-                 raise AnalysisException(
+                 exception = AnalysisException(
                      f"[NULLABLE_COLUMN_OR_FIELD] Column or field `{field.name}` is nullable while it's required to be non-nullable."
                  )
+                 attach_custom_error_code(
+                     exception, ErrorCodes.INVALID_OPERATION
+                 )
+                 raise exception

              # Check type casting validation
              if not _can_cast_column_in_schema(
                  snowpark_field.datatype, proto_to_snowpark_type(field.data_type)
              ):
-                 raise AnalysisException(
+                 exception = AnalysisException(
                      f"""[INVALID_COLUMN_OR_FIELD_DATA_TYPE] Column or field `{field.name}` is of type "{map_snowpark_to_pyspark_types(proto_to_snowpark_type(field.data_type))}" while it's required to be "{map_snowpark_to_pyspark_types(snowpark_field.datatype)}"."""
                  )
+                 attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+                 raise exception
      if len(already_existing_columns) == len(new_column_names):
          # All columns already exist, we're doing a simple update.
          snowpark_new_column_names = []
@@ -761,9 +783,11 @@ def map_with_columns(
          name = names_list[0]
          name_normalized = input_container.column_map._normalized_spark_name(name)
          if name_normalized in seen_columns:
-             raise ValueError(
+             exception = ValueError(
                  f"[COLUMN_ALREADY_EXISTS] The column `{name}` already exists."
              )
+             attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+             raise exception
          seen_columns.add(name_normalized)
          # If the column name is already in the DataFrame, we replace it, so we use the
          # mapping to get the correct column name.
@@ -772,7 +796,9 @@ def map_with_columns(
              [name]
          )
          if len(all_instances_of_spark_column_name) == 0:
-             raise KeyError(f"Spark column name {name} does not exist")
+             exception = KeyError(f"Spark column name {name} does not exist")
+             attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+             raise exception
          with_columns_names.extend(all_instances_of_spark_column_name)
          with_columns_exprs.extend(
              [expr.col] * len(all_instances_of_spark_column_name)
@@ -852,7 +878,9 @@ def map_unpivot(
      # Spark API: df.unpivot([id_columns], [unpivot_columns], var_column, val_column)
      # Snowpark API: df.unpivot(val_column, var_column, [unpivot_columns])
      if rel.unpivot.HasField("values") and len(rel.unpivot.values.values) == 0:
-         raise SparkException.unpivot_requires_value_columns()
+         exception = SparkException.unpivot_requires_value_columns()
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         raise exception

      input_container = map_relation(rel.unpivot.input)
      input_df = input_container.dataframe
@@ -893,7 +921,7 @@ def map_unpivot(
          )
          if not get_lease_common_ancestor_classes(type_list):
              # TODO: match exactly how spark shows mismatched columns
-             raise SparkException.unpivot_value_data_type_mismatch(
+             exception = SparkException.unpivot_value_data_type_mismatch(
                  ", ".join(
                      [
                          f"{dtype} {column_name}"
@@ -901,6 +929,8 @@
                      ]
                  )
              )
+             attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+             raise exception
          return not is_same_type and contains_numeric_type

      def get_column_names(
@@ -1097,7 +1127,9 @@ def map_group_map(
          snowpark_grouping_expressions.append(snowpark_column.col)
          group_name_list.append(new_name)
      if rel.group_map.func.python_udf is None:
-         raise ValueError("group_map relation without python udf is not supported")
+         exception = ValueError("group_map relation without python udf is not supported")
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception

      python_major, python_minor = rel.group_map.func.python_udf.python_ver.split(".")
      is_compatible_python = sys.version_info.major == int(
@@ -17,6 +17,8 @@ from snowflake.snowpark_connect.column_name_handler import (
  )
  from snowflake.snowpark_connect.config import get_boolean_session_config_param
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.expression.map_expression import map_expression
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
@@ -84,11 +86,13 @@ def map_extension(
              input_df = result.dataframe
              snowpark_col_names = result.column_map.get_snowpark_columns()
              if len(subquery_aliases.aliases) != len(snowpark_col_names):
-                 raise AnalysisException(
+                 exception = AnalysisException(
                      "Number of column aliases does not match number of columns. "
                      f"Number of column aliases: {len(subquery_aliases.aliases)}; "
                      f"number of columns: {len(snowpark_col_names)}."
                  )
+                 attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+                 raise exception
              return DataFrameContainer.create_with_column_mapping(
                  dataframe=input_df,
                  spark_column_names=subquery_aliases.aliases,
@@ -108,18 +112,22 @@

              left_queries = left_df.queries["queries"]
              if len(left_queries) != 1:
-                 raise SnowparkConnectNotImplementedError(
+                 exception = SnowparkConnectNotImplementedError(
                      f"Unexpected number of queries: {len(left_queries)}"
                  )
+                 attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                 raise exception
              left_query = left_queries[0]
              with push_outer_dataframe(left_result):
                  right_result = map_relation(lateral_join.right)
              right_df = right_result.dataframe
              right_queries = right_df.queries["queries"]
              if len(right_queries) != 1:
-                 raise SnowparkConnectNotImplementedError(
+                 exception = SnowparkConnectNotImplementedError(
                      f"Unexpected number of queries: {len(right_queries)}"
                  )
+                 attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                 raise exception
              right_query = right_queries[0]
              input_df_sql = f"WITH __left AS ({left_query}) SELECT * FROM __left INNER JOIN LATERAL ({right_query})"
              session = snowpark.Session.get_active_session()
@@ -139,7 +147,11 @@
          case "aggregate":
              return map_aggregate(extension.aggregate, rel.common.plan_id)
          case other:
-             raise SnowparkConnectNotImplementedError(f"Unexpected extension {other}")
+             exception = SnowparkConnectNotImplementedError(
+                 f"Unexpected extension {other}"
+             )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception


  def get_udtf_project(relation: relation_proto.Relation) -> bool:
@@ -174,7 +186,9 @@ def handle_udtf_with_table_arguments(
      session = snowpark.Session.get_active_session()
      udtf_name_lower = udtf_info.function_name.lower()
      if udtf_name_lower not in session._udtfs:
-         raise ValueError(f"UDTF '{udtf_info.function_name}' not found.")
+         exception = ValueError(f"UDTF '{udtf_info.function_name}' not found.")
+         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+         raise exception
      _udtf_obj, udtf_spark_output_names = session._udtfs[udtf_name_lower]

      table_containers = []
@@ -188,10 +202,12 @@
          if not get_boolean_session_config_param(
              "spark.sql.tvf.allowMultipleTableArguments.enabled"
          ):
-             raise AnalysisException(
+             exception = AnalysisException(
                  "[TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS] Multiple table arguments are not enabled. "
                  "Please set `spark.sql.tvf.allowMultipleTableArguments.enabled` to `true`"
              )
+             attach_custom_error_code(exception, ErrorCodes.CONFIG_NOT_ENABLED)
+             raise exception

      base_df = table_containers[0][0].dataframe
      first_table_col_count = len(base_df.columns)
@@ -339,9 +355,11 @@ def map_aggregate(
              exp, input_container.column_map, typer
          )
          if len(new_names) != 1:
-             raise SnowparkConnectNotImplementedError(
+             exception = SnowparkConnectNotImplementedError(
                  "Multi-column aggregate expressions are not supported"
              )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          return new_names[0], snowpark_column

      raw_groupings: list[tuple[str, TypedColumn]] = []
@@ -474,9 +492,11 @@
                  snowpark.GroupingSets(*sets_mapped)
              )
          case other:
-             raise SnowparkConnectNotImplementedError(
+             exception = SnowparkConnectNotImplementedError(
                  f"Unsupported GROUP BY type: {other}"
              )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception

      result = result.agg(*aggregations, exclude_grouping_columns=True)

@@ -5,6 +5,7 @@
  from functools import reduce

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+ from pyspark.errors import AnalysisException

  import snowflake.snowpark.functions as snowpark_fn
  from snowflake import snowpark
@@ -12,7 +13,11 @@ from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
- from snowflake.snowpark_connect.error.error_utils import SparkException
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import (
+     SparkException,
+     attach_custom_error_code,
+ )
  from snowflake.snowpark_connect.expression.map_expression import (
      map_single_column_expression,
  )
@@ -62,7 +67,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
      match rel.join.join_type:
          case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
              # TODO: Understand what UNSPECIFIED Join type is
-             raise SnowparkConnectNotImplementedError("Unspecified Join Type")
+             exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          case relation_proto.Join.JOIN_TYPE_INNER:
              join_type = "inner"
          case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
@@ -78,7 +85,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
          case relation_proto.Join.JOIN_TYPE_CROSS:
              join_type = "cross"
          case other:
-             raise SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+             exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception

      # This handles case sensitivity for using_columns
      case_corrected_right_columns: list[str] = []
@@ -124,9 +133,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
          is None
          for c in using_columns
      ):
-         import pyspark
-
-         raise pyspark.errors.AnalysisException(
+         exception = AnalysisException(
              USING_COLUMN_NOT_FOUND_ERROR.format(
                  next(
                      c
@@ -140,6 +147,8 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                  left_container.column_map.get_spark_columns(),
              )
          )
+         attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+         raise exception
      if any(
          right_container.column_map.get_snowpark_column_name_from_spark_column_name(
              c, allow_non_exists=True, return_first=True
@@ -147,9 +156,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
          is None
          for c in using_columns
      ):
-         import pyspark
-
-         raise pyspark.errors.AnalysisException(
+         exception = AnalysisException(
              USING_COLUMN_NOT_FOUND_ERROR.format(
                  next(
                      c
@@ -163,6 +170,8 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                  right_container.column_map.get_spark_columns(),
              )
          )
+         attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+         raise exception

      # Round trip the using columns through the column map to get the correct names
      # in order to support case sensitivity.
@@ -227,7 +236,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
          result = joined_df.drop(*(right for _, right in snowpark_using_columns))
      else:
          if join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
-             raise SparkException.implicit_cartesian_product("inner")
+             exception = SparkException.implicit_cartesian_product("inner")
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          result: snowpark.DataFrame = left_input.join(
              right=right_input,
              how=join_type,
@@ -19,6 +19,8 @@ from snowflake.snowpark_connect.column_name_handler import (
      make_column_names_snowpark_compatible,
  )
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.type_mapping import (
      get_python_sql_utils_class,
      map_json_schema_to_snowpark,
@@ -327,9 +329,11 @@
          column_metadata=column_metadata,
      )
  else:
-     raise SnowparkConnectNotImplementedError(
+     exception = SnowparkConnectNotImplementedError(
          "LocalRelation without data & schema is not supported"
      )
+     attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+     raise exception


  def map_range(
@@ -8,6 +8,8 @@ import pandas
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto

  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.utils.cache import (
      df_cache_map_get,
      df_cache_map_put_if_absent,
@@ -103,7 +105,9 @@ def map_relation(
      else:
          # This happens when the relation is empty, usually because the incoming message
          # type was incorrectly routed here.
-         raise SnowparkConnectNotImplementedError("No Relation Type")
+         exception = SnowparkConnectNotImplementedError("No Relation Type")
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception

      result: DataFrameContainer | pandas.DataFrame
      operation = rel.WhichOneof("rel_type")
@@ -121,11 +125,19 @@ def map_relation(
                  case relation_proto.Aggregate.GroupType.GROUP_TYPE_PIVOT:
                      result = map_aggregate.map_pivot_aggregate(rel)
                  case other:
-                     raise SnowparkConnectNotImplementedError(f"AGGREGATE {other}")
+                     exception = SnowparkConnectNotImplementedError(
+                         f"AGGREGATE {other}"
+                     )
+                     attach_custom_error_code(
+                         exception, ErrorCodes.UNSUPPORTED_OPERATION
+                     )
+                     raise exception
          case "approx_quantile":
              result = map_stats.map_approx_quantile(rel)
          case "as_of_join":
-             raise SnowparkConnectNotImplementedError("AS_OF_JOIN")
+             exception = SnowparkConnectNotImplementedError("AS_OF_JOIN")
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          case "catalog": # TODO: order these alphabetically
              result = map_catalog.map_catalog(rel.catalog)
          case "collect_metrics":
@@ -179,9 +191,11 @@ def map_relation(
                  (get_session_id(), rel.cached_local_relation.hash)
              )
              if cached_df is None:
-                 raise ValueError(
+                 exception = ValueError(
                      f"Local relation with hash {rel.cached_local_relation.hash} not found in cache."
                  )
+                 attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                 raise exception
              return cached_df
          case "map_partitions":
              result = map_map_partitions.map_map_partitions(rel)
@@ -235,7 +249,13 @@ def map_relation(
                  case relation_proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT:
                      result = map_row_ops.map_except(rel)
                  case other:
-                     raise SnowparkConnectNotImplementedError(f"SET_OP {other}")
+                     exception = SnowparkConnectNotImplementedError(
+                         f"SET_OP {other}"
+                     )
+                     attach_custom_error_code(
+                         exception, ErrorCodes.UNSUPPORTED_OPERATION
+                     )
+                     raise exception
          case "show_string":
              result = map_show_string.map_show_string(rel)
          case "sort":
@@ -261,11 +281,17 @@ def map_relation(
          case "with_columns_renamed":
              result = map_column_ops.map_with_columns_renamed(rel)
          case "with_relations":
-             raise SnowparkConnectNotImplementedError("WITH_RELATIONS")
+             exception = SnowparkConnectNotImplementedError("WITH_RELATIONS")
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception
          case "group_map":
              result = map_column_ops.map_group_map(rel)
          case other:
-             raise SnowparkConnectNotImplementedError(f"Other Relation {other}")
+             exception = SnowparkConnectNotImplementedError(
+                 f"Other Relation {other}"
+             )
+             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+             raise exception

      # Store container in plan cache
      if isinstance(result, DataFrameContainer):
@@ -29,12 +29,17 @@ from snowflake.snowpark_connect.column_name_handler import (
  )
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
  from snowflake.snowpark_connect.expression.map_expression import (
      map_single_column_expression,
  )
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
+ from snowflake.snowpark_connect.utils.identifiers import (
+     split_fully_qualified_spark_name,
+ )
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
@@ -55,9 +60,11 @@ def map_deduplicate(
          rel.deduplicate.HasField("within_watermark")
          and rel.deduplicate.within_watermark
      ):
-         raise AnalysisException(
+         exception = AnalysisException(
              "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets"
          )
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception

      if (
          rel.deduplicate.HasField("all_columns_as_keys")
@@ -131,11 +138,19 @@ def map_fillna(
      input_df = input_container.dataframe

      if len(rel.fill_na.cols) > 0:
+         if rel.fill_na.cols == ["*"]:
+             # Expand "*" to all columns
+             spark_col_names = input_container.column_map.get_spark_columns()
+         else:
+             spark_col_names = list(rel.fill_na.cols)
+
+         # We don't validate the fully qualified spark name here as fillNa is no-op for structured type colums.
+         # It only works for scalar type columns like float, int, string or bool.
          columns: list[str] = [
              input_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                 c
+                 split_fully_qualified_spark_name(c)[0]
              )
-             for c in rel.fill_na.cols
+             for c in spark_col_names
          ]
          values = [get_literal_field_and_name(v)[0] for v in rel.fill_na.values]
          if len(values) == 1:
@@ -212,7 +227,9 @@ def map_union(
      spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
      if left_dtypes != right_dtypes and not rel.set_op.by_name:
          if len(left_dtypes) != len(right_dtypes):
-             raise AnalysisException("UNION: the number of columns must match")
+             exception = AnalysisException("UNION: the number of columns must match")
+             attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+             raise exception
          target_left_dtypes, target_right_dtypes = [], []
          for left_type, right_type in zip(left_dtypes, right_dtypes):
              match (left_type, right_type):
@@ -248,9 +265,11 @@ def map_union(
                          not spark_sql_ansi_enabled
                          or snowpark.types.StringType() not in [left_type, right_type]
                      ): # In ansi mode , string type union boolean type is acceptable
-                         raise AnalysisException(
+                         exception = AnalysisException(
                              f"""[INCOMPATIBLE_COLUMN_TYPE] UNION can only be performed on tables with compatible column types. "{str(left_type)}" type which is not compatible with "{str(right_type)}". """
                          )
+                         attach_custom_error_code(exception, ErrorCodes.TYPE_MISMATCH)
+                         raise exception
                      target_left_dtypes.append(left_type)
                      target_right_dtypes.append(right_type)
                  case _:
@@ -776,7 +795,9 @@ def map_sample(

      frac = rel.sample.upper_bound - rel.sample.lower_bound
      if frac < 0 or frac > 1:
-         raise IllegalArgumentException("Sample fraction must be between 0 and 1")
+         exception = IllegalArgumentException("Sample fraction must be between 0 and 1")
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         raise exception
      # The seed argument is not supported here. There are a number of reasons that implementing
      # this will be complicated in Snowflake. Here is a list of complications:
      #
@@ -791,9 +812,11 @@ def map_sample(
      # these issues.
      if rel.sample.with_replacement:
          # TODO: Use a random number generator with ROW_NUMBER and SELECT.
-         raise SnowparkConnectNotImplementedError(
+         exception = SnowparkConnectNotImplementedError(
              "Sample with replacement is not supported"
          )
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception
      else:
          result: snowpark.DataFrame = input_df.sample(frac=frac)
      return DataFrameContainer(
@@ -901,9 +924,13 @@ def _union_by_name_optimized(
          set_schema_getter(result, lambda: StructType(result_fields))
          return result
      else:
-         raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
-             missing_left, missing_right
+         exception = (
+             SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
+                 missing_left, missing_right
+             )
          )
+         attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+         raise exception

      result = left_df.unionAllByName(
          right_df, allow_missing_columns=allow_missing_columns