snowpark-connect 0.20.2__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect has been flagged as potentially problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
- snowflake/snowpark_connect/column_name_handler.py +6 -65
- snowflake/snowpark_connect/config.py +28 -14
- snowflake/snowpark_connect/dataframe_container.py +242 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
- snowflake/snowpark_connect/expression/map_extension.py +2 -1
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
- snowflake/snowpark_connect/expression/map_unresolved_function.py +279 -43
- snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
- snowflake/snowpark_connect/expression/typer.py +6 -6
- snowflake/snowpark_connect/proto/control_pb2.py +17 -16
- snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
- snowflake/snowpark_connect/relation/map_aggregate.py +72 -47
- snowflake/snowpark_connect/relation/map_catalog.py +2 -2
- snowflake/snowpark_connect/relation/map_column_ops.py +207 -144
- snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
- snowflake/snowpark_connect/relation/map_extension.py +81 -56
- snowflake/snowpark_connect/relation/map_join.py +72 -63
- snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
- snowflake/snowpark_connect/relation/map_map_partitions.py +21 -16
- snowflake/snowpark_connect/relation/map_relation.py +22 -16
- snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
- snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
- snowflake/snowpark_connect/relation/map_show_string.py +42 -5
- snowflake/snowpark_connect/relation/map_sql.py +155 -78
- snowflake/snowpark_connect/relation/map_stats.py +88 -39
- snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
- snowflake/snowpark_connect/relation/map_udtf.py +6 -9
- snowflake/snowpark_connect/relation/read/map_read.py +8 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_json.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
- snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
- snowflake/snowpark_connect/relation/utils.py +11 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
- snowflake/snowpark_connect/relation/write/map_write.py +199 -40
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
- snowflake/snowpark_connect/server.py +34 -4
- snowflake/snowpark_connect/type_mapping.py +2 -23
- snowflake/snowpark_connect/utils/cache.py +27 -22
- snowflake/snowpark_connect/utils/context.py +33 -17
- snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
- snowflake/snowpark_connect/utils/session.py +41 -34
- snowflake/snowpark_connect/utils/telemetry.py +1 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/METADATA +5 -3
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/RECORD +67 -64
- snowpark_connect-0.21.0.dist-info/licenses/LICENSE-binary +568 -0
- snowpark_connect-0.21.0.dist-info/licenses/NOTICE-binary +1533 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.21.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.21.0.dist-info}/top_level.txt +0 -0
@@ -69,7 +69,10 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     set_schema_getter,
 )
-from snowflake.snowpark_connect.config import
+from snowflake.snowpark_connect.config import (
+    get_boolean_session_config_param,
+    global_config,
+)
 from snowflake.snowpark_connect.constants import (
     DUPLICATE_KEY_FOUND_ERROR_TEMPLATE,
     SPARK_TZ_ABBREVIATIONS_OVERRIDES,
@@ -100,6 +103,7 @@ from snowflake.snowpark_connect.typed_column import (
 )
 from snowflake.snowpark_connect.utils.context import (
     add_sql_aggregate_function,
+    get_current_grouping_columns,
     get_is_aggregate_function,
     get_is_evaluating_sql,
     get_is_in_udtf_context,
@@ -341,6 +345,9 @@ def map_unresolved_function(
     )
     spark_col_names = []
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
+    spark_sql_legacy_allow_hash_on_map_type = (
+        global_config.spark_sql_legacy_allowHashOnMapType
+    )
 
     function_name = exp.unresolved_function.function_name.lower()
     telemetry.report_function_usage(function_name)
@@ -867,14 +874,30 @@ def map_unresolved_function(
             )
         case "approx_percentile" | "percentile_approx":
             # SNOW-1955784: Support accuracy parameter
+            # Use percentile_disc to return actual values from dataset (matches PySpark behavior)
 
-
-
-
-
-
+            def _pyspark_approx_percentile(
+                column: Column, percentage: float, original_type: DataType
+            ) -> Column:
+                """
+                PySpark-compatible percentile that returns actual values from dataset.
+                - PySpark's approx_percentile returns the "smallest value in the ordered col values
+                  such that no more than percentage of col values is less than or equal to that value"
+                - This means it MUST return an actual value from the original dataset
+                - Snowflake's approx_percentile() may interpolate between values, breaking compatibility
+                - percentile_disc() returns discrete values (actual dataset values), matching PySpark
+                """
+                # Even though the Spark function accepts a Column for percentage, it will fail unless it's a literal.
+                # Therefore, we can do error checking right here.
+                if not 0.0 <= percentage <= 1.0:
                     raise AnalysisException("percentage must be between [0.0, 1.0]")
-
+
+                result = snowpark_fn.function("percentile_disc")(
+                    snowpark_fn.lit(percentage)
+                ).within_group(column)
+                return snowpark_fn.cast(result, original_type)
+
+            column_type = snowpark_typed_args[0].typ
 
             if isinstance(snowpark_typed_args[1].typ, ArrayType):
                 # Snowpark doesn't accept a list of percentile values.
@@ -882,26 +905,26 @@ def map_unresolved_function(
                 array_func = exp.unresolved_function.arguments[1].unresolved_function
                 assert array_func.function_name == "array", array_func
 
-
-
-
-
-
-
-
-                )
+                percentile_results = [
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], unwrap_literal(arg), column_type
+                    )
+                    for arg in array_func.arguments
+                ]
+
+                result_type = ArrayType(element_type=column_type, contains_null=False)
                 result_exp = snowpark_fn.cast(
-
-
+                    snowpark_fn.array_construct(*percentile_results),
+                    result_type,
                 )
-                result_type = ArrayType(element_type=DoubleType(), contains_null=False)
             else:
+                # Handle single percentile
+                percentage = unwrap_literal(exp.unresolved_function.arguments[1])
                 result_exp = TypedColumn(
-
-                    snowpark_args[0],
-                    _check_percentage(exp.unresolved_function.arguments[1]),
+                    _pyspark_approx_percentile(
+                        snowpark_args[0], percentage, column_type
+                    ),
                     ),
-                lambda: [
+                    lambda: [column_type],
                 )
         case "array":
             if len(snowpark_args) == 0:
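A minimal plain-Python sketch (not part of the package; name and signature hypothetical) of the percentile_disc rule the new `_pyspark_approx_percentile` helper leans on: take the first element of the sorted data whose cumulative fraction reaches the requested percentage, so the result is always an actual value from the input rather than an interpolation.

```python
def percentile_disc_sketch(values, percentage):
    # Hypothetical illustration of discrete-percentile semantics.
    if not 0.0 <= percentage <= 1.0:
        raise ValueError("percentage must be between [0.0, 1.0]")
    ordered = sorted(values)
    for i, v in enumerate(ordered, start=1):
        # cumulative fraction of rows covered once `v` is included
        if i / len(ordered) >= percentage:
            return v
    return ordered[-1]

print(percentile_disc_sketch([1, 2, 3, 4], 0.5))  # 2, an actual element of the data;
# an interpolating percentile would return 2.5 for the same input.
```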
@@ -2073,14 +2096,22 @@ def map_unresolved_function(
             assert (
                 len(exp.unresolved_function.arguments) == 2
             ), "date_format takes 2 arguments"
-
-
-
-
-
-
-
-
+
+            # Check if format parameter is NULL
+            format_literal = unwrap_literal(exp.unresolved_function.arguments[1])
+            if format_literal is None:
+                # If format is NULL, return NULL for all rows
+                result_exp = snowpark_fn.lit(None)
+            else:
+                result_exp = snowpark_fn.date_format(
+                    snowpark_args[0],
+                    snowpark_fn.lit(
+                        map_spark_timestamp_format_expression(
+                            exp.unresolved_function.arguments[1],
+                            snowpark_typed_args[0].typ,
+                        )
+                    ),
+                )
             result_exp = TypedColumn(result_exp, lambda: [StringType()])
         case "date_from_unix_date":
             result_exp = snowpark_fn.date_add(
@@ -2535,6 +2566,19 @@ def map_unresolved_function(
                 input_types=[StringType(), StringType(), StructType()],
             )
             def _from_csv(csv_data: str, schema: str, options: Optional[dict]):
+                if csv_data is None:
+                    return None
+
+                if csv_data == "":
+                    # Return dict with None values for empty string
+                    schemas = schema.split(",")
+                    results = {}
+                    for sc in schemas:
+                        parts = [i for i in sc.split(" ") if len(i) != 0]
+                        assert len(parts) == 2, f"{sc} is not a valid schema"
+                        results[parts[0]] = None
+                    return results
+
                 max_chars_per_column = -1
                 sep = ","
 
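A standalone sketch of the empty-string branch added to `_from_csv` above (hypothetical helper, plain Python): the DDL-style schema string is split into fields and every field name is mapped to None.

```python
def empty_csv_row(schema: str) -> dict:
    # "a INT, b STRING" -> {"a": None, "b": None}
    results = {}
    for field in schema.split(","):
        parts = [p for p in field.split(" ") if p]
        assert len(parts) == 2, f"{field} is not a valid schema"
        results[parts[0]] = None
    return results

print(empty_csv_row("a INT, b STRING"))  # {'a': None, 'b': None}
```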
@@ -2617,7 +2661,9 @@ def map_unresolved_function(
                 case _:
                     raise ValueError("Unrecognized from_csv parameters")
 
-            result_exp = snowpark_fn.
+            result_exp = snowpark_fn.when(
+                snowpark_args[0].is_null(), snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(csv_result, ddl_schema))
             result_type = ddl_schema
         case "from_json":
             # TODO: support options.
@@ -2651,6 +2697,9 @@ def map_unresolved_function(
             # try to parse first, since spark returns null for invalid json
             result_exp = snowpark_fn.call_function("try_parse_json", snowpark_args[0])
 
+            # Check if the original input is NULL - if so, return NULL for the entire result
+            original_input_is_null = snowpark_args[0].is_null()
+
             # helper function to make sure we have the expected array element type
             def _element_type_matches(
                 array_exp: Column, element_type: DataType
@@ -2749,9 +2798,13 @@ def map_unresolved_function(
                 else:
                     return exp
 
-
-
-
+            # Apply the coercion to handle invalid JSON (creates struct with NULL fields)
+            coerced_exp = _coerce_to_type(result_exp, result_type)
+
+            # If the original input was NULL, return NULL instead of a struct
+            result_exp = snowpark_fn.when(
+                original_input_is_null, snowpark_fn.lit(None)
+            ).otherwise(snowpark_fn.cast(coerced_exp, result_type))
         case "from_unixtime":
 
             def raise_analysis_exception(
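Both the from_csv and from_json changes wrap the parsed value in the same NULL-passthrough pattern. A minimal sketch using the snowflake.snowpark.functions calls already used in this diff (the helper name is hypothetical):

```python
from snowflake.snowpark import functions as snowpark_fn

def null_passthrough(input_col, parsed_col, target_type):
    # NULL in -> NULL out; otherwise cast the parsed value to the expected type.
    return snowpark_fn.when(
        input_col.is_null(), snowpark_fn.lit(None)
    ).otherwise(snowpark_fn.cast(parsed_col, target_type))
```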
@@ -2896,10 +2949,53 @@ def map_unresolved_function(
             )
         case "grouping" | "grouping_id":
             # grouping_id is not an alias for grouping in PySpark, but Snowflake's implementation handles both
-
+            current_grouping_cols = get_current_grouping_columns()
+            if function_name == "grouping_id":
+                if not snowpark_args:
+                    # grouping_id() with empty args means use all grouping columns
+                    spark_function_name = "grouping_id()"
+                    snowpark_args = [
+                        column_mapping.get_snowpark_column_name_from_spark_column_name(
+                            spark_col
+                        )
+                        for spark_col in current_grouping_cols
+                    ]
+                else:
+                    # Verify that grouping arguments match current grouping columns
+                    spark_col_args = [
+                        column_mapping.get_spark_column_name_from_snowpark_column_name(
+                            sp_col.getName()
+                        )
+                        for sp_col in snowpark_args
+                    ]
+                    if current_grouping_cols != spark_col_args:
+                        raise AnalysisException(
+                            f"[GROUPING_ID_COLUMN_MISMATCH] Columns of grouping_id: {spark_col_args} doesnt match "
+                            f"Grouping columns: {current_grouping_cols}"
+                        )
+            if function_name == "grouping_id":
+                result_exp = snowpark_fn.grouping_id(*snowpark_args)
+            else:
+                result_exp = snowpark_fn.grouping(*snowpark_args)
             result_type = LongType()
         case "hash":
             # TODO: See the spark-compatibility-issues.md explanation, this is quite different from Spark.
+            # MapType columns as input should raise an exception as they are not hashable.
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+            # Snowflake's hash function does allow MAP types, but Spark does not. Therefore, if we have the expansion flag enabled
+            # we want to let it pass through and hash MAP types.
+            # Also allow if the legacy config spark.sql.legacy.allowHashOnMapType is set to true
+            if not snowflake_compat and not spark_sql_legacy_allow_hash_on_map_type:
+                for arg in snowpark_typed_args:
+                    if any(isinstance(t, MapType) for t in arg.types):
+                        raise AnalysisException(
+                            '[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "hash(value)" due to data type mismatch: '
+                            'Input to the function `hash` cannot contain elements of the "MAP" type. '
+                            'In Spark, same maps may have different hashcode, thus hash expressions are prohibited on "MAP" elements. '
+                            'To restore previous behavior set "spark.sql.legacy.allowHashOnMapType" to "true".'
+                        )
             result_exp = snowpark_fn.hash(*snowpark_args)
             result_type = LongType()
         case "hex":
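A simplified sketch of how the two escape hatches interact for `hash` (standalone, names hypothetical): the MAP check is skipped when either the Snowflake extension behavior or spark.sql.legacy.allowHashOnMapType is enabled.

```python
from snowflake.snowpark.types import MapType

def check_hash_arguments(arg_types, snowflake_compat: bool, allow_hash_on_map: bool) -> None:
    # Raise when a MAP-typed argument is hashed and neither escape hatch is enabled.
    if snowflake_compat or allow_hash_on_map:
        return
    for t in arg_types:
        if isinstance(t, MapType):
            raise ValueError(
                'Input to the function `hash` cannot contain elements of the "MAP" type; '
                'set "spark.sql.legacy.allowHashOnMapType" to "true" to restore the old behavior.'
            )
```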
@@ -2934,6 +3030,14 @@ def map_unresolved_function(
             result_type = StringType()
         case "histogram_numeric":
             aggregate_input_typ = snowpark_typed_args[0].typ
+
+            if isinstance(aggregate_input_typ, DecimalType):
+                # mimic bug from Spark 3.5.3.
+                # In 3.5.5 it's fixed and this exception shouldn't be thrown
+                raise ValueError(
+                    "class org.apache.spark.sql.types.Decimal cannot be cast to class java.lang.Number (org.apache.spark.sql.types.Decimal is in unnamed module of loader 'app'; java.lang.Number is in module java.base of loader 'bootstrap')"
+                )
+
             histogram_return_type = ArrayType(
                 StructType(
                     [
@@ -3154,6 +3258,18 @@ def map_unresolved_function(
             )
             result_type = histogram_return_type
         case "hll_sketch_agg":
+            # check if input type is correct
+            if type(snowpark_typed_args[0].typ) not in [
+                IntegerType,
+                LongType,
+                StringType,
+                BinaryType,
+            ]:
+                type_str = snowpark_typed_args[0].typ.simpleString().upper()
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT" or "STRING" or "BINARY") type, however "{snowpark_arg_names[0]}" has the type "{type_str}".'
+                )
+
             match snowpark_args:
                 case [sketch]:
                     spark_function_name = (
@@ -3796,12 +3912,47 @@ def map_unresolved_function(
             )
 
             result_type = StringType()
-        case "ltrim":
+        case "ltrim" | "rtrim":
+            function_name_argument = (
+                "TRAILING" if function_name == "rtrim" else "LEADING"
+            )
             if len(snowpark_args) == 2:
                 # Only possible using SQL
-                spark_function_name = f"TRIM(
+                spark_function_name = f"TRIM({function_name_argument} {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
             result_exp = snowpark_fn.ltrim(*snowpark_args)
             result_type = StringType()
+            if isinstance(snowpark_typed_args[0].typ, BinaryType):
+                argument_name = snowpark_arg_names[0]
+                if exp.unresolved_function.arguments[0].HasField("literal"):
+                    argument_name = f"""X'{exp.unresolved_function.arguments[0].literal.binary.hex()}'"""
+                if len(snowpark_args) == 1:
+                    spark_function_name = f"{function_name}({argument_name})"
+                    trim_value = snowpark_fn.lit(b"\x20")
+                if len(snowpark_args) == 2:
+                    # Only possible using SQL
+                    trim_arg = snowpark_arg_names[1]
+                    if isinstance(
+                        snowpark_typed_args[1].typ, BinaryType
+                    ) and exp.unresolved_function.arguments[1].HasField("literal"):
+                        trim_arg = f"""X'{exp.unresolved_function.arguments[1].literal.binary.hex()}'"""
+                        trim_value = snowpark_args[1]
+                    else:
+                        trim_value = snowpark_fn.lit(None)
+                    function_name_argument = (
+                        "TRAILING" if function_name == "rtrim" else "LEADING"
+                    )
+                    spark_function_name = f"TRIM({function_name_argument} {trim_arg} FROM {argument_name})"
+                result_exp = _trim_helper(
+                    snowpark_args[0], trim_value, snowpark_fn.lit(function_name)
+                )
+                result_type = BinaryType()
+            else:
+                if function_name == "ltrim":
+                    result_exp = snowpark_fn.ltrim(*snowpark_args)
+                    result_type = StringType()
+                elif function_name == "rtrim":
+                    result_exp = snowpark_fn.rtrim(*snowpark_args)
+                    result_type = StringType()
         case "make_date":
             y = snowpark_args[0].cast(LongType())
             m = snowpark_args[1].cast(LongType())
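The X'..' form used above for binary literal arguments is plain bytes-to-hex formatting; a tiny illustration with assumed values:

```python
arg = b"\x20"                                # the default trim byte used above
rendered = f"X'{arg.hex()}'"                 # "X'20'"
print(f"TRIM(LEADING {rendered} FROM col)")  # TRIM(LEADING X'20' FROM col)
```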
@@ -4258,6 +4409,17 @@ def map_unresolved_function(
                 lambda: snowpark_typed_args[0].types,
             )
         case "md5":
+            snowflake_compat = get_boolean_session_config_param(
+                "enable_snowflake_extension_behavior"
+            )
+
+            # MD5 in Spark only accepts BinaryType or types that can be implicitly cast to it (StringType)
+            if not snowflake_compat:
+                if not isinstance(snowpark_typed_args[0].typ, (BinaryType, StringType)):
+                    raise AnalysisException(
+                        f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "md5({snowpark_arg_names[0]})" due to data type mismatch: '
+                        f'Parameter 1 requires the "BINARY" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                    )
             result_exp = snowpark_fn.md5(snowpark_args[0])
             result_type = StringType(32)
         case "median":
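A hedged PySpark illustration (needs a local Spark session; not part of the package) of the rule the new md5 check mirrors: md5 accepts BINARY, or STRING via implicit cast, and other input types are expected to fail at analysis time.

```python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("abc", 1)], ["s", "i"])

df.select(F.md5("s")).show()  # 900150983cd24fb0d6963f7d28e17f72
# df.select(F.md5("i"))       # expected: AnalysisException (DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE)
```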
@@ -5466,13 +5628,28 @@ def map_unresolved_function(
         case "row_number":
             result_exp = snowpark_fn.row_number()
             result_exp = TypedColumn(result_exp, lambda: [LongType()])
-        case "rtrim":
-            if len(snowpark_args) == 2:
-                # Only possible using SQL
-                spark_function_name = f"TRIM(TRAILING {snowpark_arg_names[1]} FROM {snowpark_arg_names[0]})"
-            result_exp = snowpark_fn.rtrim(*snowpark_args)
-            result_type = StringType()
         case "schema_of_csv":
+            # Validate that the input is a foldable STRING expression
+            if (
+                exp.unresolved_function.arguments[0].WhichOneof("expr_type")
+                != "literal"
+            ):
+                raise AnalysisException(
+                    "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                    f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                    'the input csv should be a foldable "STRING" expression; however, '
+                    f'got "{snowpark_arg_names[0]}".'
+                )
+
+            if isinstance(snowpark_typed_args[0].typ, StringType):
+                if exp.unresolved_function.arguments[0].literal.string == "":
+                    raise AnalysisException(
+                        "[DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "
+                        f'"schema_of_csv({snowpark_arg_names[0]})" due to data type mismatch: '
+                        'the input csv should be a foldable "STRING" expression; however, '
+                        f'got "{snowpark_arg_names[0]}".'
+                    )
+
             snowpark_args = [
                 typed_arg.column(to_semi_structure=True)
                 for typed_arg in snowpark_typed_args
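A hedged PySpark illustration of the foldable-input rule enforced above: schema_of_csv accepts a literal (foldable) string, and a column reference is expected to be rejected at analysis time.

```python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("1,abc",)], ["csv"])

df.select(F.schema_of_csv(F.lit("1,abc"))).show(truncate=False)  # e.g. STRUCT<_c0: INT, _c1: STRING>
# df.select(F.schema_of_csv(F.col("csv")))  # expected: AnalysisException (NON_FOLDABLE_INPUT)
```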
@@ -5689,6 +5866,16 @@ def map_unresolved_function(
             )
             result_type = ArrayType(ArrayType(StringType()))
         case "sequence":
+            if snowpark_typed_args[0].typ != snowpark_typed_args[1].typ or (
+                not isinstance(snowpark_typed_args[0].typ, _IntegralType)
+                or not isinstance(snowpark_typed_args[1].typ, _IntegralType)
+            ):
+                raise AnalysisException(
+                    f"""[DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES] Cannot resolve "sequence({snowpark_arg_names[0]}, {snowpark_arg_names[1]})" due to data type mismatch: `sequence` uses the wrong parameter type. The parameter type must conform to:
+1. The start and stop expressions must resolve to the same type.
+2. Otherwise, if start and stop expressions resolve to the "INTEGRAL" type, then the step expression must resolve to the same type.
+"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.sequence(*snowpark_args),
                 ArrayType(LongType(), contains_null=False),
@@ -6274,6 +6461,10 @@ def map_unresolved_function(
             )
             result_type = TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_millis":
+            if not isinstance(snowpark_typed_args[0].typ, _IntegralType):
+                raise AnalysisException(
+                    f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "timestamp_millis({snowpark_arg_names[0]}" due to data type mismatch: Parameter 1 requires the "INTEGRAL" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".'
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(snowpark_args[0] * 1_000, 6),
                 TimestampType(snowpark.types.TimestampTimeZone.NTZ),
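A worked example of the arithmetic above, assuming Snowflake's TO_TIMESTAMP(<numeric>, <scale>) divides the input by 10**scale: milliseconds are scaled to microseconds (x * 1_000) and then interpreted with scale 6.

```python
from datetime import datetime, timezone

millis = 1_700_000_000_123      # input to timestamp_millis
micros = millis * 1_000         # snowpark_args[0] * 1_000
seconds = micros / 10**6        # scale 6 -> divide by 10**6
print(datetime.fromtimestamp(seconds, tz=timezone.utc))
# 2023-11-14 22:13:20.123000+00:00
```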
@@ -6283,6 +6474,10 @@ def map_unresolved_function(
             # Spark allows seconds to be fractional. Snowflake does not allow that
             # even though the documentation explicitly says that it does.
             # As a workaround, use integer milliseconds instead of fractional seconds.
+            if not isinstance(snowpark_typed_args[0].typ, _NumericType):
+                raise AnalysisException(
+                    f"""AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the "NUMERIC" type, however "{snowpark_arg_names[0]}" has the type "{snowpark_typed_args[0].typ}".;"""
+                )
             result_exp = snowpark_fn.cast(
                 snowpark_fn.to_timestamp(
                     snowpark_fn.cast(snowpark_args[0] * 1_000_000, LongType()), 6
@@ -7116,6 +7311,12 @@ def map_unresolved_function(
                     )
                 )
             )
+            raise_fn = _raise_error_udf_helper(BinaryType())
+            result_exp = (
+                snowpark_fn.when(unbase_arg.is_null(), snowpark_fn.lit(None))
+                .when(result_exp.is_null(), raise_fn(snowpark_fn.lit("Invalid input")))
+                .otherwise(result_exp)
+            )
             result_type = BinaryType()
         case "unhex":
             # Non string columns, convert them to string type. This mimics pyspark behavior.
@@ -7316,6 +7517,15 @@ def map_unresolved_function(
             )
             result_type = LongType()
         case "when" | "if":
+            # Validate that the condition is a boolean expression
+            if len(snowpark_typed_args) > 0:
+                condition_type = snowpark_typed_args[0].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter 1 requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
+
             name_components = ["CASE"]
             name_components.append("WHEN")
             name_components.append(snowpark_arg_names[0])
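A hedged PySpark illustration of the condition-type rule the new check mirrors: the condition of WHEN (and of if()) must be BOOLEAN.

```python
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,)], ["x"])

df.select(F.when(F.col("x") > 0, "pos").otherwise("non-pos")).show()  # OK: boolean condition
# df.select(F.when(F.col("x"), "pos"))  # expected: AnalysisException, condition is INT not BOOLEAN
```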
@@ -7334,6 +7544,13 @@ def map_unresolved_function(
                 name_components.append(snowpark_arg_names[i])
                 name_components.append("THEN")
                 name_components.append(snowpark_arg_names[i + 1])
+                # Validate each WHEN condition
+                condition_type = snowpark_typed_args[i].typ
+                if not isinstance(condition_type, BooleanType):
+                    raise AnalysisException(
+                        f"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve CASE WHEN condition due to data type mismatch: "
+                        f"Parameter {i + 1} requires the 'BOOLEAN' type, however got '{condition_type}'"
+                    )
                 result_exp = result_exp.when(snowpark_args[i], snowpark_args[i + 1])
                 result_type_indexes.append(i + 1)
             name_components.append("END")
@@ -9448,3 +9665,22 @@ def _validate_number_format_string(format_str: str) -> None:
         raise AnalysisException(
             f"[INVALID_FORMAT.WRONG_NUM_DIGIT] The format is invalid: '{format_str}'. The format string requires at least one number digit."
         )
+
+
+def _trim_helper(value: Column, trim_value: Column, trim_type: Column) -> Column:
+    @cached_udf(
+        return_type=BinaryType(),
+        input_types=[BinaryType(), BinaryType(), StringType()],
+    )
+    def _binary_trim_udf(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
+        if value is None or trim_value is None:
+            return value
+        if trim_type in ("rtrim", "btrim", "trim"):
+            while value.endswith(trim_value):
+                value = value[: -len(trim_value)]
+        if trim_type in ("ltrim", "btrim", "trim"):
+            while value.startswith(trim_value):
+                value = value[len(trim_value) :]
+        return value
+
+    return _binary_trim_udf(value, trim_value, trim_type)
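The UDF above removes whole byte-sequence units from the ends of a BINARY value; a plain-Python worked example of the same loop:

```python
def binary_trim(value: bytes, trim_value: bytes, trim_type: str) -> bytes:
    if value is None or trim_value is None:
        return value
    if trim_type in ("rtrim", "btrim", "trim"):
        while value.endswith(trim_value):
            value = value[: -len(trim_value)]
    if trim_type in ("ltrim", "btrim", "trim"):
        while value.startswith(trim_value):
            value = value[len(trim_value):]
    return value

print(binary_trim(b"\x20\x20abc\x20", b"\x20", "ltrim"))  # b'abc '
print(binary_trim(b"abab-data-ab", b"ab", "rtrim"))       # b'abab-data-'
```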
@@ -13,10 +13,10 @@ from snowflake.snowpark.types import StructType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.
+from snowflake.snowpark_connect.utils.context import get_outer_dataframes
+from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
-from snowflake.snowpark_connect.utils.context import get_outer_dataframes
 
 
 def check_struct_and_get_field_datatype(field_name, schema):
@@ -66,8 +66,8 @@ def map_unresolved_star(
     )
 
     if len(spark_names) == 0:
-        for
-        column_mapping_for_outer_df =
+        for outer_df_container in get_outer_dataframes():
+            column_mapping_for_outer_df = outer_df_container.column_map
             (
                 spark_names,
                 snowpark_names,
@@ -106,8 +106,8 @@ def map_unresolved_star(
         )
     )
     if prefix_candidate is None:
-        for
-        prefix_candidate =
+        for outer_df_container in get_outer_dataframes():
+            prefix_candidate = outer_df_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 prefix_candidate_str, allow_non_exists=True
             )
             if prefix_candidate is not None:
@@ -184,8 +184,8 @@ def map_unresolved_star_struct(
         )
     )
     if prefix_candidate is None:
-        for
-        prefix_candidate =
+        for outer_df_container in get_outer_dataframes():
+            prefix_candidate = outer_df_container.column_map.get_snowpark_column_name_from_spark_column_name(
                 prefix_candidate_str, allow_non_exists=True
             )
             if prefix_candidate is not None:
@@ -10,7 +10,7 @@ from snowflake.snowpark.types import DataType, StringType, StructField, StructType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.
+from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
 
@@ -29,7 +29,7 @@ class ExpressionTyper:
         types = self._try_to_type_attribute_or_literal(self.df, column)
         if not types and get_df_before_projection():
             types = self._try_to_type_attribute_or_literal(
-                get_df_before_projection(), column
+                get_df_before_projection().dataframe, column
             )
         if not types:
             # df.select().schema results in DESCRIBE call to Snowflake, so avoid it if possible
@@ -42,17 +42,17 @@ class ExpressionTyper:
         try:
             return self._get_df_datatypes(df, column)
         except SnowparkClientException:  # Fallback to the df before projection
-
-            if
+            df_container = get_df_before_projection()
+            if df_container is None:
                 raise
 
-            df = self._join_df_with_outer_dataframes(
+            df = self._join_df_with_outer_dataframes(df_container.dataframe)
             return self._get_df_datatypes(df, column)
 
     @staticmethod
     def _join_df_with_outer_dataframes(df: DataFrame) -> DataFrame:
-        for
-        df = df.join(
+        for outer_df_container in get_outer_dataframes():
+            df = df.join(outer_df_container.dataframe)
 
         return df
 
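The ExpressionTyper and map_unresolved_star changes above stop passing bare Snowpark DataFrames around and instead thread a container object (see the new snowflake/snowpark_connect/dataframe_container.py in the file list). The class itself is not shown in this diff; a minimal hypothetical shape consistent with the .dataframe and .column_map accesses would be:

```python
from dataclasses import dataclass

from snowflake.snowpark import DataFrame
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap


@dataclass
class DataFrameContainer:      # hypothetical sketch, not the packaged implementation
    dataframe: DataFrame       # the underlying Snowpark DataFrame
    column_map: ColumnNameMap  # Spark <-> Snowpark column-name mapping
```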
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
 # source: control.proto
+# Protobuf Python Version: 4.25.1
 """Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
@@ -15,21 +16,21 @@ _sym_db = _symbol_database.Default()
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rcontrol.proto\x12\rsnowflake.ses\"*\n\x06\x43onfig\x12\x14\n\x07log_ast\x18\x01 \x01(\x08H\x00\x88\x01\x01\x42\n\n\x08_log_ast\"\x1e\n\x0bPingRequest\x12\x0f\n\x07payload\x18\x01 \x01(\t\"\x1f\n\x0cPingResponse\x12\x0f\n\x07payload\x18\x01 \x01(\t\"+\n\x14GetRequestAstRequest\x12\x13\n\x0b\x66orce_flush\x18\x01 \x01(\x08\"M\n\x15GetRequestAstResponse\x12\x16\n\x0espark_requests\x18\x01 \x03(\x0c\x12\x1c\n\x14snowpark_ast_batches\x18\x02 \x03(\t2\xe8\x01\n\x0e\x43ontrolService\x12\x39\n\tConfigure\x12\x15.snowflake.ses.Config\x1a\x15.snowflake.ses.Config\x12?\n\x04Ping\x12\x1a.snowflake.ses.PingRequest\x1a\x1b.snowflake.ses.PingResponse\x12Z\n\rGetRequestAst\x12#.snowflake.ses.GetRequestAstRequest\x1a$.snowflake.ses.GetRequestAstResponseb\x06proto3')
 
-
-_builder.
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'control_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
-
   DESCRIPTOR._options = None
-  _CONFIG._serialized_start=32
-  _CONFIG._serialized_end=74
-  _PINGREQUEST._serialized_start=76
-  _PINGREQUEST._serialized_end=106
-  _PINGRESPONSE._serialized_start=108
-  _PINGRESPONSE._serialized_end=139
-  _GETREQUESTASTREQUEST._serialized_start=141
-  _GETREQUESTASTREQUEST._serialized_end=184
-  _GETREQUESTASTRESPONSE._serialized_start=186
-  _GETREQUESTASTRESPONSE._serialized_end=263
-  _CONTROLSERVICE._serialized_start=266
-  _CONTROLSERVICE._serialized_end=498
+  _globals['_CONFIG']._serialized_start=32
+  _globals['_CONFIG']._serialized_end=74
+  _globals['_PINGREQUEST']._serialized_start=76
+  _globals['_PINGREQUEST']._serialized_end=106
+  _globals['_PINGRESPONSE']._serialized_start=108
+  _globals['_PINGRESPONSE']._serialized_end=139
+  _globals['_GETREQUESTASTREQUEST']._serialized_start=141
+  _globals['_GETREQUESTASTREQUEST']._serialized_end=184
+  _globals['_GETREQUESTASTRESPONSE']._serialized_start=186
+  _globals['_GETREQUESTASTRESPONSE']._serialized_end=263
+  _globals['_CONTROLSERVICE']._serialized_start=266
+  _globals['_CONTROLSERVICE']._serialized_end=498
 # @@protoc_insertion_point(module_scope)