snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic.

Files changed (67)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
  2. snowflake/snowpark_connect/column_name_handler.py +6 -65
  3. snowflake/snowpark_connect/config.py +28 -14
  4. snowflake/snowpark_connect/dataframe_container.py +242 -0
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
  6. snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
  7. snowflake/snowpark_connect/expression/map_extension.py +2 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
  10. snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
  11. snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
  12. snowflake/snowpark_connect/expression/typer.py +6 -6
  13. snowflake/snowpark_connect/proto/control_pb2.py +17 -16
  14. snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
  15. snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
  16. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
  17. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
  18. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
  19. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
  20. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
  21. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
  22. snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
  23. snowflake/snowpark_connect/relation/map_catalog.py +2 -2
  24. snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
  25. snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
  26. snowflake/snowpark_connect/relation/map_extension.py +81 -56
  27. snowflake/snowpark_connect/relation/map_join.py +72 -63
  28. snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
  29. snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
  30. snowflake/snowpark_connect/relation/map_relation.py +22 -16
  31. snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
  32. snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
  33. snowflake/snowpark_connect/relation/map_show_string.py +42 -5
  34. snowflake/snowpark_connect/relation/map_sql.py +155 -78
  35. snowflake/snowpark_connect/relation/map_stats.py +88 -39
  36. snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
  37. snowflake/snowpark_connect/relation/map_udtf.py +6 -9
  38. snowflake/snowpark_connect/relation/read/map_read.py +8 -3
  39. snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
  40. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
  41. snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
  42. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
  43. snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
  44. snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
  45. snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
  46. snowflake/snowpark_connect/relation/utils.py +11 -5
  47. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
  48. snowflake/snowpark_connect/relation/write/map_write.py +199 -40
  49. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
  50. snowflake/snowpark_connect/server.py +34 -4
  51. snowflake/snowpark_connect/type_mapping.py +2 -23
  52. snowflake/snowpark_connect/utils/cache.py +27 -22
  53. snowflake/snowpark_connect/utils/context.py +33 -17
  54. snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
  55. snowflake/snowpark_connect/utils/session.py +41 -34
  56. snowflake/snowpark_connect/utils/telemetry.py +1 -2
  57. snowflake/snowpark_connect/version.py +1 -1
  58. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
  59. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
  60. snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
  61. snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
  62. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
  63. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
  64. {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
  65. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
  66. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  67. {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
@@ -69,7 +69,10 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
-from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.config import (
+    get_boolean_session_config_param,
+    global_config,
+)
 from snowflake.snowpark_connect.constants import (
     DUPLICATE_KEY_FOUND_ERROR_TEMPLATE,
     SPARK_TZ_ABBREVIATIONS_OVERRIDES,
@@ -100,6 +103,7 @@ from snowflake.snowpark_connect.typed_column import (
 )
 from snowflake.snowpark_connect.utils.context import (
     add_sql_aggregate_function,
+    get_current_grouping_columns,
     get_is_aggregate_function,
     get_is_evaluating_sql,
     get_is_in_udtf_context,
@@ -341,6 +345,9 @@ def map_unresolved_function(
     )
     spark_col_names = []
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
+    spark_sql_legacy_allow_hash_on_map_type = (
+        global_config.spark_sql_legacy_allowHashOnMapType
+    )
 
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
@@ -867,14 +874,30 @@ def map_unresolved_function(
            )
         case "approx_percentile" | "percentile_approx":
             # SNOW-1955784: Support accuracy parameter
+            # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-            # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
-            # Therefore, we can do error checking right here.
-            def _check_percentage(exp: expressions_proto.Expression) -> Column:
-                perc = unwrap_literal(exp)
-                if not 0.0 <= perc <= 1.0:
+            def _pyspark_approx_percentile(
+                column: Column, percentage: float, original_type: DataType
+            ) -> Column:
+                """
+                PySpark-compatible percentile that returns actual values from dataset.
+                - PySpark's approx_percentile returns the "smallest value in the ordered col values
+                  such that no more than percentage of col values is less than or equal to that value"
+                - This means it MUST return an actual value from the original dataset
+                - Snowflake's approx_percentile() may interpolate between values, breaking compatibility
+                - percentile_disc() returns discrete values (actual dataset values), matching PySpark
+                """
+                # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
+                # Therefore, we can do error checking right here.
+                if not 0.0 <= percentage <= 1.0:
                     raise AnalysisException("percentage must be between [0.0, 1.0]")
-                return snowpark_fn.lit(perc)
+
+                result = snowpark_fn.function("percentile_disc")(
+                    snowpark_fn.lit(percentage)
+                ).within_group(column)
+                return snowpark_fn.cast(result, original_type)
+
+            column_type = snowpark_typed_args[0].typ
 
             if isinstance(snowpark_typed_args[1].typ, ArrayType):
                 # Snowpark doesn't accept a list of percentile values.
@@ -882,26 +905,26 @@ def map_unresolved_function(
                 array_func = exp.unresolved_function.arguments[1].unresolved_function
                 assert array_func.function_name == "array", array_func
 
-                result_exp = snowpark_fn.array_construct(
-                    *[
-                        snowpark_fn.approx_percentile(
-                            snowpark_args[0], _check_percentage(arg)
-                        )
-                        for arg in array_func.arguments
-                    ]
-                )
+                percentile_results = [
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], unwrap_literal(arg), column_type
+                    )
+                    for arg in array_func.arguments
+                ]
+
+                result_type = ArrayType(element_type=column_type, contains_null=False)
                 result_exp = snowpark_fn.cast(
-                    result_exp,
-                    ArrayType(element_type=DoubleType(), contains_null=False),
+                    snowpark_fn.array_construct(*percentile_results),
+                    result_type,
                 )
-                result_type = ArrayType(element_type=DoubleType(), contains_null=False)
             else:
+                # Handle single percentile
+                percentage = unwrap_literal(exp.unresolved_function.arguments[1])
                 result_exp = TypedColumn(
-                    snowpark_fn.approx_percentile(
-                        snowpark_args[0],
-                        _check_percentage(exp.unresolved_function.arguments[1]),
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], percentage, column_type
                     ),
-                    lambda: [DoubleType()],
+                    lambda: [column_type],
                 )
         case "array":
             if len(snowpark_args) == 0:
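
As an aside, a minimal sketch of the PySpark semantics this hunk targets, run against a plain local Spark session (the DataFrame and values are illustrative, not part of this package):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(v,) for v in [1, 2, 3, 4, 10]], ["x"])

# percentile_approx returns an actual value from the column, so the median of
# [1, 2, 3, 4, 10] is 3, never an interpolated 3.5 -- the behavior percentile_disc preserves.
df.select(F.percentile_approx("x", 0.5).alias("p50")).show()

# A list of percentages yields an array of values drawn from the column.
df.select(F.percentile_approx("x", [0.25, 0.5, 0.75]).alias("quartiles")).show()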
@@ -2073,14 +2096,22 @@ def map_unresolved_function(
             assert (
                 len(exp.unresolved_function.arguments) == 2
             ), "date_format takes 2 arguments"
-            result_exp = snowpark_fn.date_format(
-                snowpark_args[0],
-                snowpark_fn.lit(
-                    map_spark_timestamp_format_expression(
-                        exp.unresolved_function.arguments[1], snowpark_typed_args[0].typ
-                    )
-                ),
-            )
+
+            # Check if format parameter is NULL
+            format_literal = unwrap_literal(exp.unresolved_function.arguments[1])
+            if format_literal is None:
+                # If format is NULL, return NULL for all rows
+                result_exp = snowpark_fn.lit(None)
+            else:
+                result_exp = snowpark_fn.date_format(
+                    snowpark_args[0],
+                    snowpark_fn.lit(
+                        map_spark_timestamp_format_expression(
+                            exp.unresolved_function.arguments[1],
+                            snowpark_typed_args[0].typ,
+                        )
+                    ),
+                )
             result_exp = TypedColumn(result_exp, lambda: [StringType()])
         case "date_from_unix_date":
             result_exp = snowpark_fn.date_add(
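
A short sketch of the NULL-format behavior the new branch implements, reusing the spark session and F import from the earlier sketch; the data is illustrative:

df = spark.createDataFrame([("2024-01-15",)], ["d"])

# Per the new branch (and Spark's null propagation for date_format), a literal
# NULL format yields NULL for every row rather than an error.
df.selectExpr("date_format(d, NULL) AS formatted").show()

# A real pattern still formats as usual.
df.selectExpr("date_format(d, 'yyyy/MM/dd') AS formatted").show()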
@@ -2535,6 +2566,19 @@ def map_unresolved_function(
                 input_types=[StringType(), StringType(), StructType()],
             )
             def _from_csv(csv_data: str, schema: str, options: Optional[dict]):
+                if csv_data is None:
+                    return None
+
+                if csv_data == "":
+                    # Return dict with None values for empty string
+                    schemas = schema.split(",")
+                    results = {}
+                    for sc in schemas:
+                        parts = [i for i in sc.split(" ") if len(i) != 0]
+                        assert len(parts) == 2, f"{sc} is not a valid schema"
+                        results[parts[0]] = None
+                    return results
+
                 max_chars_per_column = -1
                 sep = ","
 
@@ -2617,7 +2661,9 @@ def map_unresolved_function(
                 case _:
                     raise ValueError("Unrecognized from_csv parameters")
 
-            result_exp = snowpark_fn.cast(csv_result, ddl_schema)
+            result_exp = snowpark_fn.when(
+                snowpark_args[0].is_null(), snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(csv_result, ddl_schema))
             result_type = ddl_schema
         case "from_json":
             # TODO: support options.
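
A small illustration of the from_csv edge cases these two hunks cover, again with the same session; the schema and rows are made up:

schema = "a INT, b STRING"
df = spark.createDataFrame([("1,x",), ("",), (None,)], ["csv"])

# The first row parses normally, the empty string is mapped to a struct of
# NULL fields by the handler above, and the NULL input stays NULL instead of
# being turned into a struct.
df.select(F.from_csv("csv", F.lit(schema)).alias("parsed")).show(truncate=False)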
@@ -2651,6 +2697,9 @@ def map_unresolved_function(
             # try to parse first, since spark returns null for invalid json
             result_exp = snowpark_fn.call_function("try_parse_json", snowpark_args[0])
 
+            # Check if the original input is NULL - if so, return NULL for the entire result
+            original_input_is_null = snowpark_args[0].is_null()
+
             # helper function to make sure we have the expected array element type
             def _element_type_matches(
                 array_exp: Column, element_type: DataType
@@ -2749,9 +2798,13 @@ def map_unresolved_function(
                 else:
                     return exp
 
-            result_exp = snowpark_fn.cast(
-                _coerce_to_type(result_exp, result_type), result_type
-            )
+            # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
+            coerced_exp = _coerce_to_type(result_exp, result_type)
+
+            # If the original input was NULL, return NULL instead of a struct
+            result_exp = snowpark_fn.when(
+                original_input_is_null, snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(coerced_exp, result_type))
         case "from_unixtime":
 
             def raise_analysis_exception(
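
The NULL guard above distinguishes a NULL input (NULL result) from merely invalid JSON (a struct with NULL fields). A quick illustrative check with the same session:

df = spark.createDataFrame([('{"a": 1}',), ("not json",), (None,)], ["js"])

# Valid JSON parses, invalid JSON coerces to a struct whose fields are NULL,
# and a NULL input yields a NULL struct, which the when/otherwise wrapper preserves.
df.select(F.from_json("js", "a INT").alias("parsed")).show(truncate=False)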
@@ -2896,10 +2949,53 @@ def map_unresolved_function(
             )
         case "grouping" | "grouping_id":
             # grouping_id is not an alias for grouping in PySpark, but Snowflake's implementation handles both
-            result_exp = snowpark_fn.grouping(*snowpark_args)
+            current_grouping_cols = get_current_grouping_columns()
+            if function_name == "grouping_id":
+                if not snowpark_args:
+                    # grouping_id() with empty args means use all grouping columns
+                    spark_function_name = "grouping_id()"
+                    snowpark_args = [
+                        column_mapping.get_snowpark_column_name_from_spark_column_name(
+                            spark_col
+                        )
+                        for spark_col in current_grouping_cols
+                    ]
+                else:
+                    # Verify that grouping arguments match current grouping columns
+                    spark_col_args = [
+                        column_mapping.get_spark_column_name_from_snowpark_column_name(
+                            sp_col.getName()
+                        )
+                        for sp_col in snowpark_args
+                    ]
+                    if current_grouping_cols != spark_col_args:
+                        raise AnalysisException(
+                            f"[GROUPING_ID_COLUMN_MISMATCH] Columns of grouping_id: {spark_col_args} doesnt match "
+                            f"Grouping columns: {current_grouping_cols}"
+                        )
+            if function_name == "grouping_id":
+                result_exp = snowpark_fn.grouping_id(*snowpark_args)
+            else:
+                result_exp = snowpark_fn.grouping(*snowpark_args)
             result_type = LongType()
         case "hash":
             # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
+            # MapType columns as input should raise an exception as they are not hashable.
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+            # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
+            # we want to let it pass through and hash MAP types.
+            # Also allow if the legacy config spark.sql.legacy.allowHashOnMapType is set to true
+            if not snowflake_compat and not spark_sql_legacy_allow_hash_on_map_type:
+                for arg in snowpark_typed_args:
+                    if any(isinstance(t, MapType) for t in arg.types):
+                        raise AnalysisException(
+                            '[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "hash(value)" due to data type mismatch: '
+                            'Input to the function `hash` cannot contain elements of the "MAP" type. '
+                            'In Spark, same maps may have different hashcode, thus hash expressions are prohibited on "MAP" elements. '
+                            'To restore previous behavior set "spark.sql.legacy.allowHashOnMapType" to "true".'
+                        )
             result_exp = snowpark_fn.hash(*snowpark_args)
             result_type = LongType()
         case "hex":
@@ -2934,6 +3030,14 @@ def map_unresolved_function(
             result_type = StringType()
         case "histogram_numeric":
             aggregate_input_typ = snowpark_typed_args[0].typ
+
+            if isinstance(aggregate_input_typ, DecimalType):
+                # mimic bug from Spark 3.5.3.
+                # In 3.5.5 it's fixed and this exception shouldn't be thrown
+                raise ValueError(
+                    "class org.apache.spark.sql.types.Decimal cannot be cast to class java.lang.Number (org.apache.spark.sql.types.Decimal is in unnamed module of loader 'app'; java.lang.Number is in module java.base of loader 'bootstrap')"
+                )
+
             histogram_return_type = ArrayType(
                 StructType(
                     [
@@ -3154,6 +3258,18 @@ def map_unresolved_function(
                 )
             result_type = histogram_return_type
         case "hll_sketch_agg":
+            # check if input type is correct
+            if type(snowpark_typed_args[0].typ) not in [
+                IntegerType,
+                LongType,
+                StringType,
+                BinaryType,
+            ]:
+                type_str = snowpark_typed_args[0].typ.simpleString().upper()
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT" or "STRING" or "BINARY") type, however "{snowpark_arg_names[0]}" has the type "{type_str}".'
+                )
+
             match snowpark_args:
                 case [sketch]:
                     spark_function_name = (
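
For the input-type check added to hll_sketch_agg, a brief sketch assuming PySpark 3.5+ (where the sketch/estimate pair is available); the data is illustrative:

df = spark.createDataFrame([(1,), (2,), (2,), (3,)], ["v"])

# INT/BIGINT/STRING/BINARY inputs are accepted; other types fail analysis
# with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE, as in the check above.
df.agg(F.hll_sketch_estimate(F.hll_sketch_agg("v")).alias("ndv")).show()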
@@ -3796,12 +3912,47 @@ def map_unresolved_function(
             )
 
             result_type = StringType()
-        case "ltrim":
+        case "ltrim" | "rtrim":
+            function_name_argument = (
+                "TRAILING" if function_name == "rtrim" else "LEADING"
+            )
             if len(snowpark_args) == 2:
                 # Only possible using SQL
-                spark_function_name = f"TRIM(LEADING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
+                spark_function_name = f"TRIM({function_name_argument} {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
             result_exp = snowpark_fn.ltrim(*snowpark_args)
             result_type = StringType()
+            if isinstance(snowpark_typed_args[0].typ, BinaryType):
+                argument_name = snowpark_arg_names[0]
+                if exp.unresolved_function.arguments[0].HasField("literal"):
+                    argument_name = f"""X'{exp.unresolved_function.arguments[0].literal.binary.hex()}'"""
+                if len(snowpark_args) == 1:
+                    spark_function_name = f"{function_name}({argument_name})"
+                    trim_value = snowpark_fn.lit(b"\x20")
+                if len(snowpark_args) == 2:
+                    # Only possible using SQL
+                    trim_arg = snowpark_arg_names[1]
+                    if isinstance(
+                        snowpark_typed_args[1].typ, BinaryType
+                    ) and exp.unresolved_function.arguments[1].HasField("literal"):
+                        trim_arg = f"""X'{exp.unresolved_function.arguments[1].literal.binary.hex()}'"""
+                        trim_value = snowpark_args[1]
+                    else:
+                        trim_value = snowpark_fn.lit(None)
+                    function_name_argument = (
+                        "TRAILING" if function_name == "rtrim" else "LEADING"
+                    )
+                    spark_function_name = f"TRIM({function_name_argument} {trim_arg} FROM {argument_name})"
+                result_exp = _trim_helper(
+                    snowpark_args[0], trim_value, snowpark_fn.lit(function_name)
+                )
+                result_type = BinaryType()
+            else:
+                if function_name == "ltrim":
+                    result_exp = snowpark_fn.ltrim(*snowpark_args)
+                    result_type = StringType()
+                elif function_name == "rtrim":
+                    result_exp = snowpark_fn.rtrim(*snowpark_args)
+                    result_type = StringType()
         case "make_date":
             y = snowpark_args[0].cast(LongType())
             m = snowpark_args[1].cast(LongType())
@@ -4258,6 +4409,17 @@ def map_unresolved_function(
                 lambda: snowpark_typed_args[0].types,
             )
         case "md5":
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+
+            # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
+            if not snowflake_compat:
+                if not isinstance(snowpark_typed_args[0].typ, (BinaryType, StringType)):
+                    raise AnalysisException(
+                        f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "md5({snowpark_arg_names[0]})" due to data type mismatch: '
+                        f'Parameter 1 requires the "BINARY" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                    )
             result_exp = snowpark_fn.md5(snowpark_args[0])
             result_type = StringType(32)
         case "median":
@@ -5466,13 +5628,28 @@ def map_unresolved_function(
         case "row_number":
             result_exp = snowpark_fn.row_number()
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
-        case "rtrim":
-            if len(snowpark_args) == 2:
-                # Only possible using SQL
-                spark_function_name = f"TRIM(TRAILING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
-            result_exp = snowpark_fn.rtrim(*snowpark_args)
-            result_type = StringType()
         case "schema_of_csv":
+            # Validate that the input is a foldable STRING expression
+            if (
+                exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+                != "literal"
+            ):
+                raise AnalysisException(
+                    "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                    f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                    'the input csv should be a foldable "STRING" expression; however, '
+                    f'got "{snowpark_arg_names[0]}".'
+                )
+
+            if isinstance(snowpark_typed_args[0].typ, StringType):
+                if exp.unresolved_function.arguments[0].literal.string == "":
+                    raise AnalysisException(
+                        "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                        f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                        'the input csv should be a foldable "STRING" expression; however, '
+                        f'got "{snowpark_arg_names[0]}".'
+                    )
+
             snowpark_args = [
                 typed_arg.column(to_semi_structure=True)
                 for typed_arg in snowpark_typed_args
@@ -5689,6 +5866,16 @@ def map_unresolved_function(
             )
             result_type = ArrayType(ArrayType(StringType()))
         case "sequence":
+            if snowpark_typed_args[0].typ != snowpark_typed_args[1].typ or (
+                not isinstance(snowpark_typed_args[0].typ, _IntegralType)
+                or not isinstance(snowpark_typed_args[1].typ, _IntegralType)
+            ):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES] Cannot resolve "sequence({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: `sequence` uses the wrong parameter type. The parameter type must conform to:
+1. The start and stop expressions must resolve to the same type.
+2. Otherwise, if start and stop expressions resolve to the "INTEGRAL" type, then the step expression must resolve to the same type.
+"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.sequence(*snowpark_args),
                 ArrayType(LongType(), contains_null=False),
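
The sequence check requires both endpoints to resolve to the same integral type. A short sketch, same session as before:

df = spark.range(1)

# Matching integral endpoints work and produce [1, 2, 3, 4, 5].
df.select(F.sequence(F.lit(1), F.lit(5)).alias("seq")).show(truncate=False)
# Mismatched or non-integral endpoints are rejected with
# DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES, as in the check above.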
@@ -6274,6 +6461,10 @@ def map_unresolved_function(
             )
             result_type = TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_millis":
+            if not isinstance(snowpark_typed_args[0].typ, _IntegralType):
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "timestamp_millis({snowpark_arg_names[0]}" due to data type mismatch: Parameter 1 requires the "INTEGRAL" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(snowpark_args[0] * 1_000, 6),
                 TimestampType(snowpark.types.TimestampTimeZone.NTZ),
@@ -6283,6 +6474,10 @@ def map_unresolved_function(
             # Spark allows seconds to be fractional. Snowflake does not allow that
             # even though the documentation explicitly says that it does.
             # As a workaround, use integer milliseconds instead of fractional seconds.
+            if not isinstance(snowpark_typed_args[0].typ, _NumericType):
+                raise AnalysisException(
+                    f"""AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the "NUMERIC" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".;"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(
                     snowpark_fn.cast(snowpark_args[0] * 1_000_000, LongType()), 6
@@ -7116,6 +7311,12 @@ def map_unresolved_function(
                     )
                 )
             )
+            raise_fn = _raise_error_udf_helper(BinaryType())
+            result_exp = (
+                snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
+                .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
+                .otherwise(result_exp)
+            )
             result_type = BinaryType()
         case "unhex":
             # Non string columns, convert them to string type. This mimics pyspark behavior.
@@ -7316,6 +7517,15 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "when" | "if":
+            # Validate that the condition is a boolean expression
+            if len(snowpark_typed_args) > 0:
+                condition_type = snowpark_typed_args[0].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter 1 requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
+
             name_components = ["CASE"]
             name_components.append("WHEN")
             name_components.append(snowpark_arg_names[0])
@@ -7334,6 +7544,13 @@ def map_unresolved_function(
                 name_components.append(snowpark_arg_names[i])
                 name_components.append("THEN")
                 name_components.append(snowpark_arg_names[i + 1])
+                # Validate each WHEN condition
+                condition_type = snowpark_typed_args[i].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter {i + 1} requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
                 result_exp = result_exp.when(snowpark_args[i], snowpark_args[i + 1])
                 result_type_indexes.append(i + 1)
             name_components.append("END")
@@ -9448,3 +9665,22 @@ def _validate_number_format_string(format_str: str) -> None:
         raise AnalysisException(
             f"[INVALID_FORMAT.WRONG_NUM_DIGIT] The format is invalid: '{format_str}'. The format string requires at least one number digit."
         )
+
+
+def _trim_helper(value: Column, trim_value: Column, trim_type: Column) -> Column:
+    @cached_udf(
+        return_type=BinaryType(),
+        input_types=[BinaryType(), BinaryType(), StringType()],
+    )
+    def _binary_trim_udf(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
+        if value is None or trim_value is None:
+            return value
+        if trim_type in ("rtrim", "btrim", "trim"):
+            while value.endswith(trim_value):
+                value = value[: -len(trim_value)]
+        if trim_type in ("ltrim", "btrim", "trim"):
+            while value.startswith(trim_value):
+                value = value[len(trim_value) :]
+        return value
+
+    return _binary_trim_udf(value, trim_value, trim_type)
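
The new module-level _trim_helper wraps a UDF that repeatedly strips a byte pattern from one or both ends. A standalone plain-Python sketch of the same loop (the function name here is made up for illustration):

def binary_trim(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
    # Mirror of the UDF body: strip the pattern from the chosen end(s).
    if value is None or trim_value is None:
        return value
    if trim_type in ("rtrim", "btrim", "trim"):
        while value.endswith(trim_value):
            value = value[: -len(trim_value)]
    if trim_type in ("ltrim", "btrim", "trim"):
        while value.startswith(trim_value):
            value = value[len(trim_value):]
    return value

assert binary_trim(b"  abc  ", b"\x20", "rtrim") == b"  abc"
assert binary_trim(b"  abc  ", b"\x20", "ltrim") == b"abc  "

Note that, per the ltrim/rtrim hunk earlier, the helper is only used when the first argument is BINARY; string inputs stay on the existing snowpark_fn.ltrim/rtrim path.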
@@ -13,10 +13,10 @@ from snowflake.snowpark.types import StructType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.attribute_handling import (
+from snowflake.snowpark_connect.utils.context import get_outer_dataframes
+from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
-from snowflake.snowpark_connect.utils.context import get_outer_dataframes
 
 
 def check_struct_and_get_field_datatype(field_name, schema):
@@ -66,8 +66,8 @@ def map_unresolved_star(
     )
 
     if len(spark_names) == 0:
-        for outer_df in get_outer_dataframes():
-            column_mapping_for_outer_df = outer_df._column_map
+        for outer_df_container in get_outer_dataframes():
+            column_mapping_for_outer_df = outer_df_container.column_map
             (
                 spark_names,
                 snowpark_names,
@@ -106,8 +106,8 @@ def map_unresolved_star(
                 )
             )
         if prefix_candidate is None:
-            for outer_df in get_outer_dataframes():
-                prefix_candidate = outer_df._column_map.get_snowpark_column_name_from_spark_column_name(
+            for outer_df_container in get_outer_dataframes():
+                prefix_candidate = outer_df_container.column_map.get_snowpark_column_name_from_spark_column_name(
                     prefix_candidate_str, allow_non_exists=True
                 )
                 if prefix_candidate is not None:
@@ -184,8 +184,8 @@ def map_unresolved_star_struct(
                 )
             )
         if prefix_candidate is None:
-            for outer_df in get_outer_dataframes():
-                prefix_candidate = outer_df._column_map.get_snowpark_column_name_from_spark_column_name(
+            for outer_df_container in get_outer_dataframes():
+                prefix_candidate = outer_df_container.column_map.get_snowpark_column_name_from_spark_column_name(
                     prefix_candidate_str, allow_non_exists=True
                 )
                 if prefix_candidate is not None:
@@ -10,7 +10,7 @@ from snowflake.snowpark.types import DataType, StringType, StructField, StructType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.attribute_handling import (
+from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
 
@@ -29,7 +29,7 @@ class ExpressionTyper:
         types = self._try_to_type_attribute_or_literal(self.df, column)
         if not types and get_df_before_projection():
             types = self._try_to_type_attribute_or_literal(
-                get_df_before_projection(), column
+                get_df_before_projection().dataframe, column
            )
         if not types:
             # df.select().schema results in DESCRIBE call to Snowflake, so avoid it if possible
@@ -42,17 +42,17 @@ class ExpressionTyper:
         try:
             return self._get_df_datatypes(df, column)
         except SnowparkClientException:  # Fallback to the df before projection
-            df = get_df_before_projection()
-            if df is None:
+            df_container = get_df_before_projection()
+            if df_container is None:
                 raise
 
-            df = self._join_df_with_outer_dataframes(df)
+            df = self._join_df_with_outer_dataframes(df_container.dataframe)
             return self._get_df_datatypes(df, column)
 
     @staticmethod
     def _join_df_with_outer_dataframes(df: DataFrame) -> DataFrame:
-        for outer_df in get_outer_dataframes():
-            df = df.join(outer_df)
+        for outer_df_container in get_outer_dataframes():
+            df = df.join(outer_df_container.dataframe)
         return df
 
 
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
 # source: control.proto
+# Protobuf Python Version: 4.25.1
 """Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
@@ -15,21 +16,21 @@ _sym_db = _symbol_database.Default()
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rcontrol.proto\x12\rsnowflake.ses\"*\n\x06\x43onfig\x12\x14\n\x07log_ast\x18\x01 \x01(\x08H\x00\x88\x01\x01\x42\n\n\x08_log_ast\"\x1e\n\x0bPingRequest\x12\x0f\n\x07payload\x18\x01 \x01(\t\"\x1f\n\x0cPingResponse\x12\x0f\n\x07payload\x18\x01 \x01(\t\"+\n\x14GetRequestAstRequest\x12\x13\n\x0b\x66orce_flush\x18\x01 \x01(\x08\"M\n\x15GetRequestAstResponse\x12\x16\n\x0espark_requests\x18\x01 \x03(\x0c\x12\x1c\n\x14snowpark_ast_batches\x18\x02 \x03(\t2\xe8\x01\n\x0e\x43ontrolService\x12\x39\n\tConfigure\x12\x15.snowflake.ses.Config\x1a\x15.snowflake.ses.Config\x12?\n\x04Ping\x12\x1a.snowflake.ses.PingRequest\x1a\x1b.snowflake.ses.PingResponse\x12Z\n\rGetRequestAst\x12#.snowflake.ses.GetRequestAstRequest\x1a$.snowflake.ses.GetRequestAstResponseb\x06proto3')
 
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'control_pb2', globals())
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'control_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
-
   DESCRIPTOR._options = None
-  _CONFIG._serialized_start=32
-  _CONFIG._serialized_end=74
-  _PINGREQUEST._serialized_start=76
-  _PINGREQUEST._serialized_end=106
-  _PINGRESPONSE._serialized_start=108
-  _PINGRESPONSE._serialized_end=139
-  _GETREQUESTASTREQUEST._serialized_start=141
-  _GETREQUESTASTREQUEST._serialized_end=184
-  _GETREQUESTASTRESPONSE._serialized_start=186
-  _GETREQUESTASTRESPONSE._serialized_end=263
-  _CONTROLSERVICE._serialized_start=266
-  _CONTROLSERVICE._serialized_end=498
+  _globals['_CONFIG']._serialized_start=32
+  _globals['_CONFIG']._serialized_end=74
+  _globals['_PINGREQUEST']._serialized_start=76
+  _globals['_PINGREQUEST']._serialized_end=106
+  _globals['_PINGRESPONSE']._serialized_start=108
+  _globals['_PINGRESPONSE']._serialized_end=139
+  _globals['_GETREQUESTASTREQUEST']._serialized_start=141
+  _globals['_GETREQUESTASTREQUEST']._serialized_end=184
+  _globals['_GETREQUESTASTRESPONSE']._serialized_start=186
+  _globals['_GETREQUESTASTRESPONSE']._serialized_end=263
+  _globals['_CONTROLSERVICE']._serialized_start=266
+  _globals['_CONTROLSERVICE']._serialized_end=498
 # @@protoc_insertion_point(module_scope)