snowpark-connect 0.23.0-py3-none-any.whl → 0.24.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (25)
  1. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  2. snowflake/snowpark_connect/expression/literal.py +5 -0
  3. snowflake/snowpark_connect/expression/map_expression.py +10 -1
  4. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  5. snowflake/snowpark_connect/expression/map_unresolved_function.py +11 -12
  6. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  7. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  8. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  9. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  10. snowflake/snowpark_connect/relation/map_column_ops.py +1 -32
  11. snowflake/snowpark_connect/relation/map_extension.py +7 -7
  12. snowflake/snowpark_connect/relation/map_row_ops.py +2 -29
  13. snowflake/snowpark_connect/relation/read/utils.py +6 -7
  14. snowflake/snowpark_connect/relation/utils.py +1 -170
  15. snowflake/snowpark_connect/version.py +1 -1
  16. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +1 -1
  17. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +25 -20
  18. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  19. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  20. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  21. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  22. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  23. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  24. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  25. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,207 @@
+ #
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+ #
+ from dataclasses import dataclass
+ from typing import Any
+
+ import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2
+ import pyspark.sql.connect.proto.types_pb2 as types_pb2
+
+
+ @dataclass(frozen=True)
+ class DefaultParameter:
+     """Represents a single default parameter for a function."""
+
+     name: str
+     value: Any
+
+
+ @dataclass(frozen=True)
+ class FunctionDefaults:
+     """Represents default parameter configuration for a function."""
+
+     total_args: int
+     defaults: list[DefaultParameter]
+
+
+ # FUNCTION_DEFAULTS dictionary to hold operation name with default values.
+ # This is required as non pyspark clients such as scala or sql won't send all the parameters.
+ # We use this dict to inject the missing parameters before processing the unresolved function.
+ FUNCTION_DEFAULTS: dict[str, FunctionDefaults] = {
+     "aes_decrypt": FunctionDefaults(
+         total_args=5,
+         defaults=[
+             DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+             DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+             DefaultParameter("aad", ""),  # Spark SQL default: empty string
+         ],
+     ),
+     "aes_encrypt": FunctionDefaults(
+         total_args=6,
+         defaults=[
+             DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+             DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+             DefaultParameter(
+                 "iv", ""
+             ),  # Spark SQL default: empty string (random generated if not provided)
+             DefaultParameter("aad", ""),  # Spark SQL default: empty string
+         ],
+     ),
+     "approx_percentile": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("accuracy", 10000)],
+     ),
+     "bround": FunctionDefaults(
+         total_args=2,
+         defaults=[DefaultParameter("scale", 0)],
+     ),
+     "first": FunctionDefaults(
+         total_args=2,
+         defaults=[DefaultParameter("ignorenulls", False)],
+     ),
+     "lag": FunctionDefaults(
+         total_args=2,
+         defaults=[
+             DefaultParameter("offset", 1),
+         ],
+     ),
+     "last": FunctionDefaults(
+         total_args=2,
+         defaults=[DefaultParameter("ignorenulls", False)],
+     ),
+     "lead": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("offset", 1), DefaultParameter("default", None)],
+     ),
+     "locate": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("pos", 1)],
+     ),
+     "months_between": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("roundOff", True)],
+     ),
+     "nth_value": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("ignoreNulls", False)],
+     ),
+     "overlay": FunctionDefaults(
+         total_args=4,
+         defaults=[DefaultParameter("len", -1)],
+     ),
+     "percentile": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("frequency", 1)],
+     ),
+     "percentile_approx": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("accuracy", 10000)],
+     ),
+     "round": FunctionDefaults(
+         total_args=2,
+         defaults=[DefaultParameter("scale", 0)],
+     ),
+     "sentences": FunctionDefaults(
+         total_args=3,
+         defaults=[
+             DefaultParameter("language", ""),
+             DefaultParameter("country", ""),
+         ],
+     ),
+     "sort_array": FunctionDefaults(
+         total_args=2,
+         defaults=[DefaultParameter("asc", True)],
+     ),
+     "split": FunctionDefaults(
+         total_args=3,
+         defaults=[DefaultParameter("limit", -1)],
+     ),
+     "str_to_map": FunctionDefaults(
+         total_args=3,
+         defaults=[
+             DefaultParameter(
+                 "pairDelim", ","
+             ),  # Spark SQL default: comma for splitting pairs
+             DefaultParameter(
+                 "keyValueDelim", ":"
+             ),  # Spark SQL default: colon for splitting key/value
+         ],
+     ),
+     "try_aes_decrypt": FunctionDefaults(
+         total_args=5,
+         defaults=[
+             DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+             DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+             DefaultParameter("aad", ""),  # Spark SQL default: empty string
+         ],
+     ),
+ }
+
+
+ def _create_literal_expression(value: Any) -> expressions_pb2.Expression:
+     """Create a literal expression for the given value."""
+     expr = expressions_pb2.Expression()
+     if isinstance(value, bool):
+         expr.literal.boolean = value
+     elif isinstance(value, int):
+         expr.literal.integer = value
+     elif isinstance(value, str):
+         expr.literal.string = value
+     elif isinstance(value, float):
+         expr.literal.double = value
+     elif value is None:
+         null_type = types_pb2.DataType()
+         null_type.null.SetInParent()
+         expr.literal.null.CopyFrom(null_type)
+     else:
+         raise ValueError(f"Unsupported literal type: {value}")
+
+     return expr
+
+
+ def inject_function_defaults(
+     unresolved_function: expressions_pb2.Expression.UnresolvedFunction,
+ ) -> bool:
+     """
+     Inject missing default parameters into an UnresolvedFunction protobuf.
+
+     Args:
+         unresolved_function: The protobuf UnresolvedFunction to modify
+
+     Returns:
+         bool: True if any defaults were injected, False otherwise
+     """
+     function_name = unresolved_function.function_name.lower()
+
+     if function_name not in FUNCTION_DEFAULTS:
+         return False
+
+     func_config = FUNCTION_DEFAULTS[function_name]
+     current_arg_count = len(unresolved_function.arguments)
+     total_args = func_config.total_args
+     defaults = func_config.defaults
+
+     if not defaults or current_arg_count >= total_args:
+         return False
+
+     # Calculate how many defaults to append
+     missing_arg_count = total_args - current_arg_count
+
+     # Check if any required params are missing.
+     if missing_arg_count > len(defaults):
+         raise ValueError(
+             f"Function '{function_name}' is missing required arguments. "
+             f"Expected {total_args} args, got {current_arg_count}, "
+             f"but only {len(defaults)} defaults are defined."
+         )
+
+     defaults_to_append = defaults[-missing_arg_count:]
+     injected = False
+
+     # Simply append the needed default values
+     for default_param in defaults_to_append:
+         default_expr = _create_literal_expression(default_param.value)
+         unresolved_function.arguments.append(default_expr)
+         injected = True
+
+     return injected
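
Since the new module is self-contained, its behaviour can be exercised in isolation. Below is a minimal, illustrative sketch only (not part of the release): a two-argument locate call stands in for what a Scala or SQL client would send, and inject_function_defaults appends the missing default pos=1 as a literal.

    # Illustrative sketch; exercises the new function_defaults module as shipped.
    import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2

    from snowflake.snowpark_connect.expression.function_defaults import (
        inject_function_defaults,
    )

    # A non-PySpark client may send locate(substr, str) with only two arguments.
    fn = expressions_pb2.Expression.UnresolvedFunction()
    fn.function_name = "locate"
    for text in ("ab", "abcabc"):
        arg = expressions_pb2.Expression()
        arg.literal.string = text
        fn.arguments.append(arg)

    injected = inject_function_defaults(fn)  # appends the default pos=1
    assert injected is True
    assert len(fn.arguments) == 3
    assert fn.arguments[2].literal.integer == 1
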
@@ -81,6 +81,11 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
  case "decimal":
      # literal.decimal.precision & scale are ignored, as decimal.Decimal doesn't accept them
      return decimal.Decimal(literal.decimal.value), literal.decimal.value
+ case "array":
+     array_values, element_names = zip(
+         *(get_literal_field_and_name(e) for e in literal.array.elements)
+     )
+     return array_values, f"ARRAY({', '.join(element_names)})"
  case "null" | None:
      return None, "NULL"
  case other:
@@ -27,7 +27,10 @@ from snowflake.snowpark_connect.expression.literal import get_literal_field_and_
  from snowflake.snowpark_connect.expression.map_cast import map_cast
  from snowflake.snowpark_connect.expression.map_sql_expression import map_sql_expr
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
- from snowflake.snowpark_connect.type_mapping import map_simple_types
+ from snowflake.snowpark_connect.type_mapping import (
+     map_simple_types,
+     proto_to_snowpark_type,
+ )
  from snowflake.snowpark_connect.typed_column import TypedColumn
  from snowflake.snowpark_connect.utils.context import (
      gen_sql_plan_id,
@@ -166,6 +169,12 @@ def map_expression(
      lambda: [map_simple_types(lit_type_str)],
  )

+ if lit_type_str == "array":
+     result_exp = snowpark_fn.lit(lit_value)
+     element_types = proto_to_snowpark_type(exp.literal.array.element_type)
+     array_type = snowpark.types.ArrayType(element_types)
+     return [lit_name], TypedColumn(result_exp, lambda: [array_type])
+
  # Decimal needs further processing to get the precision and scale properly.
  if lit_type_str == "decimal":
      # Precision and scale are optional in the proto.
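
Together with the literal.py change above, this lets array literals sent by the client resolve to a typed array column. A hedged, client-side illustration follows (it assumes an active Spark Connect session named spark backed by snowpark-connect; the values are made up):

    # Illustrative only: F.lit() with a Python list is sent as an array literal,
    # which the new "array" branches name (roughly "ARRAY(1, 2, 3)") and type as
    # an ArrayType built from the proto's element type.
    from pyspark.sql import functions as F

    df = spark.range(1).select(F.lit([1, 2, 3]).alias("xs"))
    df.printSchema()  # xs resolves as an array column
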
@@ -10,7 +10,10 @@ from snowflake.snowpark.types import BooleanType
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.typed_column import TypedColumn
- from snowflake.snowpark_connect.utils.context import push_evaluating_sql_scope
+ from snowflake.snowpark_connect.utils.context import (
+     push_evaluating_sql_scope,
+     push_outer_dataframe,
+ )
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
@@ -52,12 +55,19 @@ def map_extension(
      return [name], typed_col

  case "subquery_expression":
+     from snowflake.snowpark_connect.dataframe_container import (
+         DataFrameContainer,
+     )
      from snowflake.snowpark_connect.expression.map_expression import (
          map_expression,
      )
      from snowflake.snowpark_connect.relation.map_relation import map_relation

-     with push_evaluating_sql_scope():
+     current_outer_df = DataFrameContainer(
+         dataframe=typer.df, column_map=column_mapping
+     )
+
+     with push_evaluating_sql_scope(), push_outer_dataframe(current_outer_df):
          df_container = map_relation(extension.subquery_expression.input)
          df = df_container.dataframe

@@ -80,6 +80,9 @@ from snowflake.snowpark_connect.constants import (
      SPARK_TZ_ABBREVIATIONS_OVERRIDES,
      STRUCTURED_TYPES_ENABLED,
  )
+ from snowflake.snowpark_connect.expression.function_defaults import (
+     inject_function_defaults,
+ )
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
  from snowflake.snowpark_connect.expression.map_cast import (
      CAST_FUNCTIONS,
@@ -299,6 +302,9 @@ def map_unresolved_function(
  function_name = exp.unresolved_function.function_name.lower()
  is_udtf_call = function_name in session._udtfs

+ # Inject default parameters for functions that need them (especially for Scala clients)
+ inject_function_defaults(exp.unresolved_function)
+
  def _resolve_args_expressions(exp: expressions_proto.Expression):
      def _resolve_fn_arg(exp):
          with resolving_fun_args():
@@ -3761,7 +3767,7 @@ def map_unresolved_function(
          snowpark_args[1] <= 0, snowpark_fn.lit("")
      ).otherwise(snowpark_fn.left(*snowpark_args))
      result_type = StringType()
- case "length" | "char_length" | "character_length":
+ case "length" | "char_length" | "character_length" | "len":
      if exp.unresolved_function.arguments[0].HasField("literal"):
          # Only update the name if it has the literal field.
          # If it doesn't, it means it's binary data.
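
The new len alias routes through the same branch as length. A hedged illustration from the SQL surface (it assumes an active Spark Connect session named spark whose parser exposes len):

    # Illustrative only; len and length should agree for string input.
    spark.sql("SELECT len('Spark SQL') AS a, length('Spark SQL') AS b").collect()
    # both columns evaluate to 9
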
@@ -3822,13 +3828,7 @@ def map_unresolved_function(
  case "locate":
      substr = unwrap_literal(exp.unresolved_function.arguments[0])
      value = snowpark_args[1]
-     if len(exp.unresolved_function.arguments) == 3:
-         start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
-     else:
-         # start_pos is an optional argument and if not provided we should default to 1.
-         # This path will only be reached by spark connect scala clients.
-         start_pos = 1
-         spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"
+     start_pos = unwrap_literal(exp.unresolved_function.arguments[2])

      if start_pos > 0:
          result_exp = snowpark_fn.locate(substr, value, start_pos)
@@ -4677,7 +4677,7 @@ def map_unresolved_function(
          snowpark_args[0],
      )
      result_type = DateType()
- case "not":
+ case "not" | "!":
      spark_function_name = f"(NOT {snowpark_arg_names[0]})"
      result_exp = ~snowpark_args[0]
      result_type = BooleanType()
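
The "!" spelling typically arrives from non-Python clients (for example Scala's Column negation or SQL's ! prefix); it now resolves through the same branch as not. A hedged illustration, assuming an active session named spark and a parser that accepts the ! spelling:

    # Illustrative only; both columns should come back False.
    spark.sql("SELECT !true AS a, not true AS b").collect()
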
@@ -5253,9 +5253,8 @@ def map_unresolved_function(
  # TODO: Seems like more validation of the arguments is appropriate.
  args = exp.unresolved_function.arguments
  if len(args) > 0:
-     if not (
-         isinstance(snowpark_typed_args[0].typ, IntegerType)
-         or isinstance(snowpark_typed_args[0].typ, NullType)
+     if not isinstance(
+         snowpark_typed_args[0].typ, (IntegerType, LongType, NullType)
      ):
          raise AnalysisException(
              f"""[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT") type, however {snowpark_arg_names[0]} has the type "{snowpark_typed_args[0].typ}"""
@@ -0,0 +1,16 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #