snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. Click here for more details.

Files changed (46)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  5. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  6. snowflake/snowpark_connect/expression/literal.py +14 -12
  7. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  8. snowflake/snowpark_connect/expression/map_expression.py +18 -2
  9. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  10. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
  12. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  16. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  17. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  18. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  19. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  20. snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
  21. snowflake/snowpark_connect/relation/map_extension.py +65 -31
  22. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  23. snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
  24. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  25. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  26. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  27. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  29. snowflake/snowpark_connect/server.py +25 -5
  30. snowflake/snowpark_connect/type_mapping.py +2 -2
  31. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  32. snowflake/snowpark_connect/utils/session.py +21 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  35. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
  36. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
  37. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  38. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  39. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  40. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  41. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@
5
5
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
6
6
 
7
7
  import snowflake.snowpark.functions as snowpark_fn
8
+ from snowflake.snowpark._internal.analyzer.expression import Literal
8
9
  from snowflake.snowpark.types import ArrayType, MapType, StructType, _IntegralType
9
10
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
10
11
  from snowflake.snowpark_connect.config import global_config
@@ -57,7 +58,8 @@ def map_unresolved_extract_value(
57
58
  extract_fn = snowpark_fn.get_ignore_case
58
59
  # Set index to a dummy value before we use it later in the ansi mode check.
59
60
  index = snowpark_fn.lit(1)
60
- if _check_if_array_type(extract_typed_column, child_typed_column):
61
+ is_array = _check_if_array_type(extract_typed_column, child_typed_column)
62
+ if is_array:
61
63
  # Set all non-valid array indices to NULL.
62
64
  # This is done because both conditions of a CASE WHEN statement are executed regardless of if the condition is true or not.
63
65
  # Getting a negative index in Snowflake throws an error; thus, we convert all non-valid array indices to NULL before getting the index.
@@ -74,12 +76,37 @@ def map_unresolved_extract_value(
74
76
 
75
77
  spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
76
78
 
77
- if spark_sql_ansi_enabled and _check_if_array_type(
78
- extract_typed_column, child_typed_column
79
- ):
79
+ if spark_sql_ansi_enabled and is_array:
80
80
  result_exp = snowpark_fn.when(
81
81
  index.isNull(),
82
82
  child_typed_column.col.getItem("[snowpark_connect::INVALID_ARRAY_INDEX]"),
83
83
  ).otherwise(result_exp)
84
84
 
85
- return spark_function_name, TypedColumn(result_exp, lambda: typer.type(result_exp))
85
+ def _get_extracted_value_type():
86
+ if is_array:
87
+ return [child_typed_column.typ.element_type]
88
+ elif isinstance(child_typed_column.typ, MapType):
89
+ return [child_typed_column.typ.value_type]
90
+ elif (
91
+ isinstance(child_typed_column.typ, StructType)
92
+ and isinstance(extract_typed_column.col._expr1, Literal)
93
+ and isinstance(extract_typed_column.col._expr1.value, str)
94
+ ):
95
+ struct = dict(
96
+ {
97
+ (
98
+ f.name
99
+ if global_config.spark_sql_caseSensitive
100
+ else f.name.lower(),
101
+ f.datatype,
102
+ )
103
+ for f in child_typed_column.typ.fields
104
+ }
105
+ )
106
+ key = extract_typed_column.col._expr1.value
107
+ key = key if global_config.spark_sql_caseSensitive else key.lower()
108
+
109
+ return [struct[key]] if key in struct else typer.type(result_exp)
110
+ return typer.type(result_exp)
111
+
112
+ return spark_function_name, TypedColumn(result_exp, _get_extracted_value_type)
@@ -80,6 +80,9 @@ from snowflake.snowpark_connect.constants import (
80
80
  SPARK_TZ_ABBREVIATIONS_OVERRIDES,
81
81
  STRUCTURED_TYPES_ENABLED,
82
82
  )
83
+ from snowflake.snowpark_connect.expression.function_defaults import (
84
+ inject_function_defaults,
85
+ )
83
86
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
84
87
  from snowflake.snowpark_connect.expression.map_cast import (
85
88
  CAST_FUNCTIONS,
@@ -299,6 +302,9 @@ def map_unresolved_function(
299
302
  function_name = exp.unresolved_function.function_name.lower()
300
303
  is_udtf_call = function_name in session._udtfs
301
304
 
305
+ # Inject default parameters for functions that need them (especially for Scala clients)
306
+ inject_function_defaults(exp.unresolved_function)
307
+
302
308
  def _resolve_args_expressions(exp: expressions_proto.Expression):
303
309
  def _resolve_fn_arg(exp):
304
310
  with resolving_fun_args():
@@ -1107,7 +1113,7 @@ def map_unresolved_function(
1107
1113
  result_exp = TypedColumn(
1108
1114
  result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
1109
1115
  )
1110
- case "array_size" | "cardinality":
1116
+ case "array_size":
1111
1117
  array_type = snowpark_typed_args[0].typ
1112
1118
  if not isinstance(array_type, ArrayType):
1113
1119
  raise AnalysisException(
@@ -1116,6 +1122,16 @@ def map_unresolved_function(
1116
1122
  result_exp = TypedColumn(
1117
1123
  snowpark_fn.array_size(*snowpark_args), lambda: [LongType()]
1118
1124
  )
1125
+ case "cardinality":
1126
+ arg_type = snowpark_typed_args[0].typ
1127
+ if isinstance(arg_type, (ArrayType, MapType)):
1128
+ result_exp = TypedColumn(
1129
+ snowpark_fn.size(*snowpark_args), lambda: [LongType()]
1130
+ )
1131
+ else:
1132
+ raise AnalysisException(
1133
+ f"Expected argument '{snowpark_arg_names[0]}' to have an ArrayType or MapType, but got {arg_type.simpleString()}."
1134
+ )
1119
1135
  case "array_sort":
1120
1136
  result_exp = TypedColumn(
1121
1137
  snowpark_fn.array_sort(*snowpark_args),
@@ -1295,10 +1311,35 @@ def map_unresolved_function(
1295
1311
  )
1296
1312
  result_exp = TypedColumn(result_exp, lambda: [LongType()])
1297
1313
  case "bit_get" | "getbit":
1298
- bit_get_function = snowpark_fn.function("GETBIT")
1299
- result_exp = TypedColumn(
1300
- bit_get_function(*snowpark_args), lambda: [LongType()]
1314
+ snowflake_compat = get_boolean_session_config_param(
1315
+ "enable_snowflake_extension_behavior"
1301
1316
  )
1317
+ col, pos = snowpark_args
1318
+ if snowflake_compat:
1319
+ bit_get_function = snowpark_fn.function("GETBIT")(col, pos)
1320
+ else:
1321
+ raise_error = _raise_error_helper(LongType())
1322
+ bit_get_function = snowpark_fn.when(
1323
+ (snowpark_fn.lit(0) <= pos) & (pos <= snowpark_fn.lit(63))
1324
+ | snowpark_fn.is_null(pos),
1325
+ snowpark_fn.function("GETBIT")(col, pos),
1326
+ ).otherwise(
1327
+ raise_error(
1328
+ snowpark_fn.concat(
1329
+ snowpark_fn.lit(
1330
+ "Invalid bit position: ",
1331
+ ),
1332
+ snowpark_fn.cast(
1333
+ pos,
1334
+ StringType(),
1335
+ ),
1336
+ snowpark_fn.lit(
1337
+ " exceeds the bit upper limit",
1338
+ ),
1339
+ )
1340
+ )
1341
+ )
1342
+ result_exp = TypedColumn(bit_get_function, lambda: [LongType()])
1302
1343
  case "bit_length":
1303
1344
  bit_length_function = snowpark_fn.function("bit_length")
1304
1345
  result_exp = TypedColumn(
@@ -3726,7 +3767,7 @@ def map_unresolved_function(
3726
3767
  snowpark_args[1] <= 0, snowpark_fn.lit("")
3727
3768
  ).otherwise(snowpark_fn.left(*snowpark_args))
3728
3769
  result_type = StringType()
3729
- case "length" | "char_length" | "character_length":
3770
+ case "length" | "char_length" | "character_length" | "len":
3730
3771
  if exp.unresolved_function.arguments[0].HasField("literal"):
3731
3772
  # Only update the name if it has the literal field.
3732
3773
  # If it doesn't, it means it's binary data.
@@ -4636,7 +4677,7 @@ def map_unresolved_function(
4636
4677
  snowpark_args[0],
4637
4678
  )
4638
4679
  result_type = DateType()
4639
- case "not":
4680
+ case "not" | "!":
4640
4681
  spark_function_name = f"(NOT {snowpark_arg_names[0]})"
4641
4682
  result_exp = ~snowpark_args[0]
4642
4683
  result_type = BooleanType()
@@ -5212,9 +5253,8 @@ def map_unresolved_function(
5212
5253
  # TODO: Seems like more validation of the arguments is appropriate.
5213
5254
  args = exp.unresolved_function.arguments
5214
5255
  if len(args) > 0:
5215
- if not (
5216
- isinstance(snowpark_typed_args[0].typ, IntegerType)
5217
- or isinstance(snowpark_typed_args[0].typ, NullType)
5256
+ if not isinstance(
5257
+ snowpark_typed_args[0].typ, (IntegerType, LongType, NullType)
5218
5258
  ):
5219
5259
  raise AnalysisException(
5220
5260
  f"""[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{spark_function_name}" due to data type mismatch: Parameter 1 requires the ("INT" or "BIGINT") type, however {snowpark_arg_names[0]} has the type "{snowpark_typed_args[0].typ}"""
@@ -5496,9 +5536,27 @@ def map_unresolved_function(
5496
5536
  ):
5497
5537
  result_exp = snowpark_fn.lit(None)
5498
5538
  else:
5539
+ right_expr = snowpark_fn.right(*snowpark_args)
5540
+ if isinstance(snowpark_typed_args[0].typ, TimestampType):
5541
+ # Spark format is always displayed as YYYY-MM-DD HH:mm:ss.FF6
5542
+ # When microseconds are equal to 0 .FF6 part is removed
5543
+ # When microseconds are equal to 0 at the end, they are removed i.e. .123000 -> .123 when displayed
5544
+
5545
+ formated_timestamp = snowpark_fn.to_varchar(
5546
+ snowpark_args[0], "YYYY-MM-DD HH:MI:SS.FF6"
5547
+ )
5548
+ right_expr = snowpark_fn.right(
5549
+ snowpark_fn.regexp_replace(
5550
+ snowpark_fn.regexp_replace(formated_timestamp, "0+$", ""),
5551
+ "\\.$",
5552
+ "",
5553
+ ),
5554
+ snowpark_args[1],
5555
+ )
5556
+
5499
5557
  result_exp = snowpark_fn.when(
5500
5558
  snowpark_args[1] <= 0, snowpark_fn.lit("")
5501
- ).otherwise(snowpark_fn.right(*snowpark_args))
5559
+ ).otherwise(right_expr)
5502
5560
  result_type = StringType()
5503
5561
  case "rint":
5504
5562
  result_exp = snowpark_fn.cast(
@@ -6729,6 +6787,7 @@ def map_unresolved_function(
6729
6787
  if value == "" or any(
6730
6788
  c in value for c in [",", "\n", "\r", '"', "'"]
6731
6789
  ):
6790
+ value = value.replace("\\", "\\\\").replace('"', '\\"')
6732
6791
  result.append(f'"{value}"')
6733
6792
  else:
6734
6793
  result.append(value)
@@ -0,0 +1,16 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #