snowpark-connect 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (42)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  5. snowflake/snowpark_connect/expression/literal.py +9 -12
  6. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +8 -1
  8. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  9. snowflake/snowpark_connect/expression/map_unresolved_function.py +66 -6
  10. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  11. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  12. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  13. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  14. snowflake/snowpark_connect/relation/map_column_ops.py +38 -6
  15. snowflake/snowpark_connect/relation/map_extension.py +58 -24
  16. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  17. snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
  18. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  19. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  20. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  21. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  22. snowflake/snowpark_connect/relation/read/utils.py +7 -6
  23. snowflake/snowpark_connect/relation/utils.py +170 -1
  24. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  25. snowflake/snowpark_connect/server.py +25 -5
  26. snowflake/snowpark_connect/type_mapping.py +2 -2
  27. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  28. snowflake/snowpark_connect/utils/session.py +21 -0
  29. snowflake/snowpark_connect/version.py +1 -1
  30. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  31. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
  32. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +40 -40
  33. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  34. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  35. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py

@@ -4,10 +4,8 @@

  # Proto source for reference:
  # https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
- import os
  import re
  import sys
- import time
  from collections import defaultdict
  from copy import copy, deepcopy
  from typing import Any
@@ -533,17 +531,8 @@ def set_snowflake_parameters(
                  snowpark_session.sql(
                      f"ALTER SESSION SET TIMEZONE = '{value}'"
                  ).collect()
-                 set_jvm_timezone(value)
-                 if hasattr(time, "tzset"):
-                     os.environ["TZ"] = value
-                     time.tzset()
              else:
                  snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
-                 reset_jvm_timezone_to_system_default()
-                 if hasattr(time, "tzset") and False:
-                     if "TZ" in os.environ:
-                         del os.environ["TZ"]
-                         time.tzset()
          case "spark.sql.globalTempDatabase":
              if not value:
                  value = global_config.default_static_global_config.get(key)
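
Note: the removed lines mean that setting the session time zone no longer touches the server process's TZ environment variable, time.tzset(), or the JVM default; only the Snowflake session parameter is altered. For orientation, a minimal client-side sketch of the setting that reaches this code path (standard PySpark Connect API; the remote URL is a placeholder):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Presumably handled by set_snowflake_parameters(), which now only issues
    #   ALTER SESSION SET TIMEZONE = 'America/Los_Angeles'
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")

    # Clearing the value falls through to: ALTER SESSION UNSET TIMEZONE
    spark.conf.unset("spark.sql.session.timeZone")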
snowflake/snowpark_connect/error/error_utils.py

@@ -75,6 +75,10 @@ terminate_multi_args_exception_pattern = (
  snowpark_connect_exception_pattern = re.compile(
      r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
  )
+ invalid_bit_pattern = re.compile(
+     r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
+     re.IGNORECASE,
+ )


  def contains_udtf_select(sql_string):
@@ -107,6 +111,9 @@ def _get_converted_known_sql_or_custom_exception(
          return SparkRuntimeException(
              message="Unexpected value for start in function slice: SQL array indices start at 1."
          )
+     invalid_bit = invalid_bit_pattern.search(msg)
+     if invalid_bit:
+         return IllegalArgumentException(message=invalid_bit.group(0))
      match = snowpark_connect_exception_pattern.search(
          ex.message if hasattr(ex, "message") else str(ex)
      )
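
Note: a quick self-check of what the new pattern matches; the error text comes from the GETBIT mapping added in map_unresolved_function.py further down:

    import re

    invalid_bit_pattern = re.compile(
        r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
        re.IGNORECASE,
    )

    msg = "Invalid bit position: 64 exceeds the bit upper limit"
    match = invalid_bit_pattern.search(msg)
    assert match is not None and match.group(0) == msg
    # _get_converted_known_sql_or_custom_exception() then rewraps the match
    # as IllegalArgumentException(message=match.group(0))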
snowflake/snowpark_connect/error/exceptions.py

@@ -22,3 +22,7 @@ class MissingSchema(SnowparkConnectException):
          super().__init__(
              "No default schema found in session",
          )
+
+
+ class MaxRetryExceeded(SnowparkConnectException):
+     ...
snowflake/snowpark_connect/expression/hybrid_column_map.py (new file)

@@ -0,0 +1,192 @@
+ #
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+ #
+
+ """
+ Hybrid column mapping for HAVING clause resolution.
+
+ This module provides a special column mapping that can resolve expressions
+ in the context of both the input DataFrame (for base columns) and the
+ aggregated DataFrame (for aggregate expressions and aliases).
+ """
+
+ from typing import Dict, List
+
+ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+
+ from snowflake import snowpark
+ from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
+ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+ from snowflake.snowpark_connect.typed_column import TypedColumn
+
+
+ class HybridColumnMap:
+     """
+     A column mapping that can resolve expressions in both input and aggregated contexts.
+
+     This is specifically designed for HAVING clause resolution where expressions may reference:
+     1. Base columns from the input DataFrame (to build new aggregates)
+     2. Existing aggregate expressions and their aliases
+     3. Grouping columns
+     """
+
+     def __init__(
+         self,
+         input_column_map: ColumnNameMap,
+         input_typer: ExpressionTyper,
+         aggregated_column_map: ColumnNameMap,
+         aggregated_typer: ExpressionTyper,
+         aggregate_expressions: List[expressions_proto.Expression],
+         grouping_expressions: List[expressions_proto.Expression],
+         aggregate_aliases: Dict[str, expressions_proto.Expression],
+     ) -> None:
+         self.input_column_map = input_column_map
+         self.input_typer = input_typer
+         self.aggregated_column_map = aggregated_column_map
+         self.aggregated_typer = aggregated_typer
+         self.aggregate_expressions = aggregate_expressions
+         self.grouping_expressions = grouping_expressions
+         self.aggregate_aliases = aggregate_aliases
+
+     def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
+         """Check if an expression is an aggregate function."""
+         if exp.WhichOneof("expr_type") == "unresolved_function":
+             func_name = exp.unresolved_function.function_name.lower()
+             # Common aggregate functions - expand this list as needed
+             aggregate_functions = {
+                 "avg",
+                 "average",
+                 "sum",
+                 "count",
+                 "min",
+                 "max",
+                 "stddev",
+                 "stddev_pop",
+                 "stddev_samp",
+                 "variance",
+                 "var_pop",
+                 "var_samp",
+                 "collect_list",
+                 "collect_set",
+                 "first",
+                 "last",
+                 "any_value",
+                 "bool_and",
+                 "bool_or",
+                 "corr",
+                 "covar_pop",
+                 "covar_samp",
+                 "kurtosis",
+                 "skewness",
+                 "percentile_cont",
+                 "percentile_disc",
+                 "approx_count_distinct",
+             }
+             return func_name in aggregate_functions
+         return False
+
+     def is_grouping_column(self, column_name: str) -> bool:
+         """Check if a column name refers to a grouping column."""
+         for group_exp in self.grouping_expressions:
+             if (
+                 group_exp.WhichOneof("expr_type") == "unresolved_attribute"
+                 and group_exp.unresolved_attribute.unparsed_identifier == column_name
+             ):
+                 return True
+         return False
+
+     def resolve_expression(
+         self, exp: expressions_proto.Expression
+     ) -> tuple[list[str], TypedColumn]:
+         """
+         Resolve an expression in the hybrid context.
+
+         Strategy:
+         1. If it's an aggregate function -> create new aggregate using input context
+         2. If it's an alias to existing aggregate -> use aggregated context
+         3. If it's a grouping column -> try aggregated context first, fall back to input context
+            (handles exclude_grouping_columns=True case)
+         4. Otherwise -> try input context first, then aggregated context
+         """
+         from snowflake.snowpark_connect.expression.map_expression import map_expression
+
+         expr_type = exp.WhichOneof("expr_type")
+
+         # Handle aggregate functions - need to evaluate against input DataFrame
+         if self.is_aggregate_function(exp):
+             return map_expression(exp, self.input_column_map, self.input_typer)
+
+         # Handle column references
+         if expr_type == "unresolved_attribute":
+             column_name = exp.unresolved_attribute.unparsed_identifier
+
+             # Check if it's an alias to an existing aggregate expression
+             if column_name in self.aggregate_aliases:
+                 # Use the aggregated context to get the alias
+                 return map_expression(
+                     exp, self.aggregated_column_map, self.aggregated_typer
+                 )
+
+             # Check if it's a grouping column
+             if self.is_grouping_column(column_name):
+                 # Try aggregated context first (for cases where grouping columns are included)
+                 try:
+                     return map_expression(
+                         exp, self.aggregated_column_map, self.aggregated_typer
+                     )
+                 except Exception:
+                     # Fall back to input context if grouping columns were excluded
+                     # This handles the exclude_grouping_columns=True case
+                     return map_expression(exp, self.input_column_map, self.input_typer)
+
+             # Try input context first (for base columns used in new aggregates)
+             try:
+                 return map_expression(exp, self.input_column_map, self.input_typer)
+             except Exception:
+                 # Fall back to aggregated context
+                 return map_expression(
+                     exp, self.aggregated_column_map, self.aggregated_typer
+                 )
+
+         # For other expression types, try aggregated context first (likely references to computed values)
+         try:
+             return map_expression(
+                 exp, self.aggregated_column_map, self.aggregated_typer
+             )
+         except Exception:
+             # Fall back to input context
+             return map_expression(exp, self.input_column_map, self.input_typer)
+
+
+ def create_hybrid_column_map_for_having(
+     input_df: snowpark.DataFrame,
+     input_column_map: ColumnNameMap,
+     aggregated_df: snowpark.DataFrame,
+     aggregated_column_map: ColumnNameMap,
+     aggregate_expressions: List[expressions_proto.Expression],
+     grouping_expressions: List[expressions_proto.Expression],
+     spark_columns: List[str],
+     raw_aggregations: List[tuple[str, TypedColumn]],
+ ) -> HybridColumnMap:
+     """
+     Create a HybridColumnMap instance for HAVING clause resolution.
+     """
+     # Create typers for both contexts
+     input_typer = ExpressionTyper(input_df)
+     aggregated_typer = ExpressionTyper(aggregated_df)
+
+     # Build alias mapping from spark column names to aggregate expressions
+     aggregate_aliases = {}
+     for i, (spark_name, _) in enumerate(raw_aggregations):
+         if i < len(aggregate_expressions):
+             aggregate_aliases[spark_name] = aggregate_expressions[i]
+
+     return HybridColumnMap(
+         input_column_map=input_column_map,
+         input_typer=input_typer,
+         aggregated_column_map=aggregated_column_map,
+         aggregated_typer=aggregated_typer,
+         aggregate_expressions=aggregate_expressions,
+         grouping_expressions=grouping_expressions,
+         aggregate_aliases=aggregate_aliases,
+     )
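
Note: for context, the shape of query this hybrid resolution targets. A HAVING clause may name an aggregate alias, a grouping column, or a fresh aggregate over a base input column, which are the three cases resolve_expression() distinguishes. A hedged PySpark sketch (placeholder remote URL and data):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    spark.createDataFrame([("a", 1), ("a", 5), ("b", 2)], ["k", "v"]).createOrReplaceTempView("t")

    spark.sql("""
        SELECT k, SUM(v) AS total
        FROM t
        GROUP BY k
        HAVING total > 3      -- alias of an existing aggregate
           AND k <> 'b'       -- grouping column
           AND MAX(v) < 100   -- new aggregate over a base column
    """).show()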
snowflake/snowpark_connect/expression/literal.py

@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  from tzlocal import get_localzone

  from snowflake.snowpark_connect.config import global_config
+ from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
              ).date()
              return date, f"DATE '{date}'"
          case "timestamp" | "timestamp_ntz" as t:
-             # Note - Clients need to ensure local_timezone is the same as spark_sql_session_timeZone config.
-             # No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
-             # the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
-             # timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
-             # to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
-             # official doc about it, but this behavior is based on my testings.
-             tz = (
-                 ZoneInfo(global_config.spark_sql_session_timeZone)
-                 if hasattr(global_config, "spark_sql_session_timeZone")
-                 else get_localzone()
-             )
+             local_tz = get_localzone()
              if t == "timestamp":
                  microseconds = literal.timestamp
              else:
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
                  microseconds // 1_000_000
              ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
              tz_dt = datetime.datetime.fromtimestamp(
-                 microseconds // 1_000_000, tz=tz
+                 microseconds // 1_000_000, tz=local_tz
              ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
              if t == "timestamp_ntz":
                  lit_dt = lit_dt.astimezone(datetime.timezone.utc)
                  tz_dt = tz_dt.astimezone(datetime.timezone.utc)
+             elif not get_is_evaluating_sql():
+                 config_tz = global_config.spark_sql_session_timeZone
+                 config_tz = ZoneInfo(config_tz) if config_tz else local_tz
+                 tz_dt = tz_dt.astimezone(config_tz)
+                 lit_dt = lit_dt.astimezone(local_tz)
+
              return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
          case "day_time_interval":
              # TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
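
Note: roughly, the behavior being preserved here is that a timestamp literal's printed column name follows spark.sql.session.timeZone while the literal value itself is kept in the server's local zone, with TIMESTAMP_NTZ literals and SQL evaluation (get_is_evaluating_sql()) exempted. A hedged sketch of a client call that exercises this path:

    import datetime
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    spark.conf.set("spark.sql.session.timeZone", "America/New_York")

    # lit() timestamps travel as microseconds in the Connect proto; the server
    # rebuilds the value (lit_dt) and the printed column name (tz_dt) separately.
    spark.range(1).select(F.lit(datetime.datetime(2024, 1, 1, 12, 0, 0))).show()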
snowflake/snowpark_connect/expression/map_cast.py

@@ -127,10 +127,11 @@ def map_cast(
          from_type = StringType()
      if isinstance(to_type, StringType):
          to_type = StringType()
+
+     # todo - verify if that's correct SNOW-2248680
      if isinstance(from_type, TimestampType):
          from_type = TimestampType()
-     if isinstance(to_type, TimestampType):
-         to_type = TimestampType()
+
      match (from_type, to_type):
          case (_, _) if (from_type == to_type):
              result_exp = col
@@ -185,6 +186,17 @@
          case (DateType(), TimestampType()):
              result_exp = snowpark_fn.to_timestamp(col)
              result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
+         case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
+             result_exp = col
+         case (
+             TimestampType(),
+             TimestampType() as t,
+         ) if t.tzinfo == TimestampTimeZone.NTZ:
+             zone = global_config.spark_sql_session_timeZone
+             result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
+                 TimestampType(TimestampTimeZone.NTZ)
+             )
+             # todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
          case (TimestampType(), TimestampType()):
              result_exp = col
          case (_, TimestampType()) if isinstance(from_type, _NumericType):
@@ -259,8 +271,12 @@
          case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
              to_type, _IntegralType
          ):
-             result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
-                 snowpark_fn.floor(col)
+             result_exp = (
+                 snowpark_fn.when(
+                     col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                 )
+                 .when(col < 0, snowpark_fn.ceil(col))
+                 .otherwise(snowpark_fn.floor(col))
              )
              result_exp = result_exp.cast(to_type)
          case (StringType(), _) if (isinstance(to_type, _IntegralType)):
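
Note: the fractional-to-integral branch now pins NaN to 0 before the ceil/floor truncation, matching Spark's non-ANSI cast semantics. A quick hedged check in plain PySpark:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.createDataFrame([(float("nan"),), (-1.7,), (2.9,)], ["x"])

    # Expected with ANSI mode off: NaN -> 0, -1.7 -> -1 (ceil), 2.9 -> 2 (floor)
    df.select(F.col("x").cast("int").alias("i")).show()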
snowflake/snowpark_connect/expression/map_expression.py

@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
  from snowflake import snowpark
  from snowflake.snowpark import Session
  from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
+ from snowflake.snowpark.types import TimestampTimeZone, TimestampType
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
  from snowflake.snowpark_connect.expression import (
      map_extension,
@@ -190,9 +191,15 @@
                  return [lit_name], TypedColumn(
                      snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
                  )
+             result_exp = snowpark_fn.lit(lit_value)
+
+             if lit_type_str == "timestamp_ntz" and isinstance(
+                 lit_value, datetime.datetime
+             ):
+                 result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))

              return [lit_name], TypedColumn(
-                 snowpark_fn.lit(lit_value), lambda: [map_simple_types(lit_type_str)]
+                 result_exp, lambda: [map_simple_types(lit_type_str)]
              )
          case "sort_order":
              child_name, child_column = map_single_column_expression(
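
Note: in Snowpark terms, the added branch amounts to the following for naive datetime literals tagged timestamp_ntz (a standalone sketch, not the exact surrounding code):

    import datetime
    import snowflake.snowpark.functions as snowpark_fn
    from snowflake.snowpark.types import TimestampTimeZone, TimestampType

    lit_value = datetime.datetime(2024, 1, 1, 12, 0)

    result_exp = snowpark_fn.lit(lit_value)
    # New in 0.23.0: pin the literal to TIMESTAMP_NTZ rather than relying on
    # the session's default timestamp mapping.
    result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))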
snowflake/snowpark_connect/expression/map_unresolved_extract_value.py

@@ -5,6 +5,7 @@
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto

  import snowflake.snowpark.functions as snowpark_fn
+ from snowflake.snowpark._internal.analyzer.expression import Literal
  from snowflake.snowpark.types import ArrayType, MapType, StructType, _IntegralType
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
  from snowflake.snowpark_connect.config import global_config
@@ -57,7 +58,8 @@ def map_unresolved_extract_value(
          extract_fn = snowpark_fn.get_ignore_case
      # Set index to a dummy value before we use it later in the ansi mode check.
      index = snowpark_fn.lit(1)
-     if _check_if_array_type(extract_typed_column, child_typed_column):
+     is_array = _check_if_array_type(extract_typed_column, child_typed_column)
+     if is_array:
          # Set all non-valid array indices to NULL.
          # This is done because both conditions of a CASE WHEN statement are executed regardless of if the condition is true or not.
          # Getting a negative index in Snowflake throws an error; thus, we convert all non-valid array indices to NULL before getting the index.
@@ -74,12 +76,37 @@

      spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled

-     if spark_sql_ansi_enabled and _check_if_array_type(
-         extract_typed_column, child_typed_column
-     ):
+     if spark_sql_ansi_enabled and is_array:
          result_exp = snowpark_fn.when(
              index.isNull(),
              child_typed_column.col.getItem("[snowpark_connect::INVALID_ARRAY_INDEX]"),
          ).otherwise(result_exp)

-     return spark_function_name, TypedColumn(result_exp, lambda: typer.type(result_exp))
+     def _get_extracted_value_type():
+         if is_array:
+             return [child_typed_column.typ.element_type]
+         elif isinstance(child_typed_column.typ, MapType):
+             return [child_typed_column.typ.value_type]
+         elif (
+             isinstance(child_typed_column.typ, StructType)
+             and isinstance(extract_typed_column.col._expr1, Literal)
+             and isinstance(extract_typed_column.col._expr1.value, str)
+         ):
+             struct = dict(
+                 {
+                     (
+                         f.name
+                         if global_config.spark_sql_caseSensitive
+                         else f.name.lower(),
+                         f.datatype,
+                     )
+                     for f in child_typed_column.typ.fields
+                 }
+             )
+             key = extract_typed_column.col._expr1.value
+             key = key if global_config.spark_sql_caseSensitive else key.lower()
+
+             return [struct[key]] if key in struct else typer.type(result_exp)
+         return typer.type(result_exp)
+
+     return spark_function_name, TypedColumn(result_exp, _get_extracted_value_type)
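
Note: with _get_extracted_value_type, the reported type of a container lookup now comes straight from the child's schema: the array element type, the map value type, or the matching struct field type, falling back to the typer otherwise. A hedged PySpark illustration of the three shapes:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.sql("""
        SELECT array(1, 2, 3)          AS arr,
               map('a', 1.5)           AS m,
               named_struct('x', 'hi') AS s
    """)

    # Expected types: int (element), double (map value), string (struct field)
    df.select(F.col("arr")[0], F.col("m")["a"], F.col("s")["x"]).printSchema()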
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -1107,7 +1107,7 @@ def map_unresolved_function(
              result_exp = TypedColumn(
                  result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
              )
-         case "array_size" | "cardinality":
+         case "array_size":
              array_type = snowpark_typed_args[0].typ
              if not isinstance(array_type, ArrayType):
                  raise AnalysisException(
@@ -1116,6 +1116,16 @@ def map_unresolved_function(
              result_exp = TypedColumn(
                  snowpark_fn.array_size(*snowpark_args), lambda: [LongType()]
              )
+         case "cardinality":
+             arg_type = snowpark_typed_args[0].typ
+             if isinstance(arg_type, (ArrayType, MapType)):
+                 result_exp = TypedColumn(
+                     snowpark_fn.size(*snowpark_args), lambda: [LongType()]
+                 )
+             else:
+                 raise AnalysisException(
+                     f"Expected argument '{snowpark_arg_names[0]}' to have an ArrayType or MapType, but got {arg_type.simpleString()}."
+                 )
          case "array_sort":
              result_exp = TypedColumn(
                  snowpark_fn.array_sort(*snowpark_args),
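
Note: cardinality is now split out from array_size: it accepts both arrays and maps (delegating to Snowpark's size), while array_size stays array-only, which matches Spark's own behavior. A hedged check:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.sql("SELECT array(1, 2, 3) AS arr, map('a', 1, 'b', 2) AS m")

    df.select(F.expr("cardinality(arr)"), F.expr("cardinality(m)")).show()  # 3, 2
    # array_size(m) still raises an AnalysisException; before this change,
    # cardinality(m) raised the same error.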
@@ -1295,10 +1305,35 @@ def map_unresolved_function(
              )
              result_exp = TypedColumn(result_exp, lambda: [LongType()])
          case "bit_get" | "getbit":
-             bit_get_function = snowpark_fn.function("GETBIT")
-             result_exp = TypedColumn(
-                 bit_get_function(*snowpark_args), lambda: [LongType()]
+             snowflake_compat = get_boolean_session_config_param(
+                 "enable_snowflake_extension_behavior"
              )
+             col, pos = snowpark_args
+             if snowflake_compat:
+                 bit_get_function = snowpark_fn.function("GETBIT")(col, pos)
+             else:
+                 raise_error = _raise_error_helper(LongType())
+                 bit_get_function = snowpark_fn.when(
+                     (snowpark_fn.lit(0) <= pos) & (pos <= snowpark_fn.lit(63))
+                     | snowpark_fn.is_null(pos),
+                     snowpark_fn.function("GETBIT")(col, pos),
+                 ).otherwise(
+                     raise_error(
+                         snowpark_fn.concat(
+                             snowpark_fn.lit(
+                                 "Invalid bit position: ",
+                             ),
+                             snowpark_fn.cast(
+                                 pos,
+                                 StringType(),
+                             ),
+                             snowpark_fn.lit(
+                                 " exceeds the bit upper limit",
+                             ),
+                         )
+                     )
+                 )
+             result_exp = TypedColumn(bit_get_function, lambda: [LongType()])
          case "bit_length":
              bit_length_function = snowpark_fn.function("bit_length")
              result_exp = TypedColumn(
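
Note: with the Snowflake extension behavior disabled, an out-of-range bit position is turned into a raised error whose text matches the invalid_bit_pattern added in error_utils.py, so the client sees an IllegalArgumentException instead of a raw Snowflake error. A hedged PySpark check:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.range(1).select(F.lit(11).alias("v"))

    df.select(F.getbit("v", F.lit(1))).show()   # 11 = 0b1011, so bit 1 is 1

    # Expected to fail with something like:
    #   IllegalArgumentException: Invalid bit position: 64 exceeds the bit upper limit
    df.select(F.getbit("v", F.lit(64))).show()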
@@ -3787,7 +3822,13 @@ def map_unresolved_function(
          case "locate":
              substr = unwrap_literal(exp.unresolved_function.arguments[0])
              value = snowpark_args[1]
-             start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
+             if len(exp.unresolved_function.arguments) == 3:
+                 start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
+             else:
+                 # start_pos is an optional argument and if not provided we should default to 1.
+                 # This path will only be reached by spark connect scala clients.
+                 start_pos = 1
+                 spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"

              if start_pos > 0:
                  result_exp = snowpark_fn.locate(substr, value, start_pos)
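
Note: start_pos is now optional. The PySpark client always sends a third argument (its locate() defaults pos to 1), but the Scala Spark Connect client may omit it, and that two-argument form now defaults to 1 instead of failing on the missing argument. For reference (hedged; whether SQL text routes through this same path is not shown in the hunk):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.createDataFrame([("hello world",)], ["s"])

    df.select(F.locate("world", "s")).show()                    # 7; PySpark adds pos=1
    spark.sql("SELECT locate('world', 'hello world')").show()   # two-argument form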
@@ -5496,9 +5537,27 @@ def map_unresolved_function(
              ):
                  result_exp = snowpark_fn.lit(None)
              else:
+                 right_expr = snowpark_fn.right(*snowpark_args)
+                 if isinstance(snowpark_typed_args[0].typ, TimestampType):
+                     # Spark format is always displayed as YYY-MM-DD HH:mm:ss.FF6
+                     # When microseconds are equal to 0 .FF6 part is removed
+                     # When microseconds are equal to 0 at the end, they are removed i.e. .123000 -> .123 when displayed
+
+                     formated_timestamp = snowpark_fn.to_varchar(
+                         snowpark_args[0], "YYYY-MM-DD HH:MI:SS.FF6"
+                     )
+                     right_expr = snowpark_fn.right(
+                         snowpark_fn.regexp_replace(
+                             snowpark_fn.regexp_replace(formated_timestamp, "0+$", ""),
+                             "\\.$",
+                             "",
+                         ),
+                         snowpark_args[1],
+                     )
+
                  result_exp = snowpark_fn.when(
                      snowpark_args[1] <= 0, snowpark_fn.lit("")
-                 ).otherwise(snowpark_fn.right(*snowpark_args))
+                 ).otherwise(right_expr)
              result_type = StringType()
          case "rint":
              result_exp = snowpark_fn.cast(
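
Note: the new branch reproduces how Spark stringifies a timestamp before right() is applied: trailing zeros in the fractional seconds are trimmed and a bare trailing dot is dropped. A hedged check of the expected output:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.sql("SELECT TIMESTAMP '2024-01-01 10:20:30.123000' AS ts")

    # Spark renders ts as '2024-01-01 10:20:30.123' (not '.123000'), so the
    # last six characters are '30.123'.
    df.select(F.expr("right(ts, 6)")).show()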
@@ -6729,6 +6788,7 @@ def map_unresolved_function(
                  if value == "" or any(
                      c in value for c in [",", "\n", "\r", '"', "'"]
                  ):
+                     value = value.replace("\\", "\\\\").replace('"', '\\"')
                      result.append(f'"{value}"')
                  else:
                      result.append(value)
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py

@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
  from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2


- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\x92\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')

  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
    _globals['_TABLEARGUMENTINFO']._serialized_start=931
    _globals['_TABLEARGUMENTINFO']._serialized_end=1027
    _globals['_AGGREGATE']._serialized_start=1030
-   _globals['_AGGREGATE']._serialized_end=1688
-   _globals['_AGGREGATE_PIVOT']._serialized_start=1363
-   _globals['_AGGREGATE_PIVOT']._serialized_end=1461
-   _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1463
-   _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1526
-   _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1529
-   _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1688
+   _globals['_AGGREGATE']._serialized_end=1741
+   _globals['_AGGREGATE_PIVOT']._serialized_start=1416
+   _globals['_AGGREGATE_PIVOT']._serialized_end=1514
+   _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
+   _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
+   _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
+   _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
    # @@protoc_insertion_point(module_scope)
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi

@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
      def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...

  class Aggregate(_message.Message):
-     __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
+     __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
      class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
          __slots__ = ()
          GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]
@@ -108,10 +108,12 @@
      AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
      PIVOT_FIELD_NUMBER: _ClassVar[int]
      GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
+     HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
      input: _relations_pb2.Relation
      group_type: Aggregate.GroupType
      grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
      aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
      pivot: Aggregate.Pivot
      grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
-     def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ...) -> None: ...
+     having_condition: _expressions_pb2.Expression
+     def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...