snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (46)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  5. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  6. snowflake/snowpark_connect/expression/literal.py +14 -12
  7. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  8. snowflake/snowpark_connect/expression/map_expression.py +18 -2
  9. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  10. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
  12. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  16. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  17. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  18. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  19. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  20. snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
  21. snowflake/snowpark_connect/relation/map_extension.py +65 -31
  22. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  23. snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
  24. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  25. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  26. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  27. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  29. snowflake/snowpark_connect/server.py +25 -5
  30. snowflake/snowpark_connect/type_mapping.py +2 -2
  31. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  32. snowflake/snowpark_connect/utils/session.py +21 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  35. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
  36. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
  37. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  38. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  39. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  40. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  41. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py
@@ -0,0 +1,203 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Additional Spark functions used in pandas-on-Spark.
+ """
+ from typing import Union
+
+ from pyspark import SparkContext
+ import pyspark.sql.functions as F
+ from pyspark.sql.column import Column
+
+ # For supporting Spark Connect
+ from pyspark.sql.utils import is_remote
+
+
+ def product(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_product",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+ def stddev(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_stddev",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+ def var(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_var",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+ def skew(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_skew",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+ def kurt(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_kurt",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+ def mode(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_mode",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+ def covar(col1: Column, col2: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_covar",
+             col1,  # type: ignore[arg-type]
+             col2,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+ def repeat(col: Column, n: Union[int, Column]) -> Column:
+     """
+     Repeats a string column n times, and returns it as a new string column.
+     """
+     _n = F.lit(n) if isinstance(n, int) else n
+     return F.call_udf("repeat", col, _n)
+
+
+ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "ewm",
+             col,  # type: ignore[arg-type]
+             lit(alpha),
+             lit(ignore_na),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+ def last_non_null(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "last_non_null",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+ def null_index(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "null_index",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+ def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "timestampdiff",
+             lit(unit),
+             start,  # type: ignore[arg-type]
+             end,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
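The helpers above all follow the same dispatch pattern: under Spark Connect (is_remote()) they route through _invoke_function_over_columns to a server-side function such as "pandas_stddev", otherwise they call the JVM-side PythonSQLUtils helpers. A minimal usage sketch, assuming an active session named spark and that this vendored pyspark tree under includes/python/ is the one being imported; whether the backing server-side functions are available depends on the backend:

    from pyspark.sql import SparkSession
    from pyspark.pandas.spark import functions as SF  # module added in this release

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1.0,), (2.0,), (None,)], ["x"])

    # Each helper returns a Column; execution goes through Spark Connect or
    # the classic JVM path depending on the session type.
    stats = df.select(
        SF.stddev(df["x"], ddof=1).alias("sd"),
        SF.product(df["x"], dropna=True).alias("prod"),
    )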
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py
@@ -0,0 +1,202 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Helpers and utilities to deal with PySpark instances
+ """
+ from typing import overload
+
+ from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+ @overload
+ def as_nullable_spark_type(dt: StructType) -> StructType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: MapType) -> MapType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     ...
+
+
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     """
+     Returns a nullable schema or data types.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A", IntegerType(), True),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(IntegerType(),
+     ...                 ArrayType(IntegerType(), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(IntegerType(),
+             ArrayType(IntegerType(), True), True), True),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', FloatType(), True)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     as_nullable_spark_type(field.dataType),
+                     nullable=True,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+     elif isinstance(dt, MapType):
+         return MapType(
+             as_nullable_spark_type(dt.keyType),
+             as_nullable_spark_type(dt.valueType),
+             valueContainsNull=True,
+         )
+     else:
+         return dt
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: StructType, *, precision: int = ..., scale: int = ...
+ ) -> StructType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: ArrayType, *, precision: int = ..., scale: int = ...
+ ) -> ArrayType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: MapType, *, precision: int = ..., scale: int = ...
+ ) -> MapType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = ..., scale: int = ...
+ ) -> DataType:
+     ...
+
+
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = 38, scale: int = 18
+ ) -> DataType:
+     """
+     Returns a data type with a fixed decimal type.
+
+     The precision and scale of the decimal type are fixed with the given values.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A", DecimalType(10, 0), True),
+     ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', DecimalType(38,18), True),
+         StructField('B', DecimalType(38,18), False)])
+
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(DecimalType(5, 0),
+     ...                 ArrayType(DecimalType(20, 0), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", DecimalType(30, 15), False)]),
+     ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(DecimalType(30,15),
+             ArrayType(DecimalType(30,15), False), False), False),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', DecimalType(30,15), False)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                     nullable=field.nullable,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(
+             force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+             containsNull=dt.containsNull,
+         )
+     elif isinstance(dt, MapType):
+         return MapType(
+             force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+             force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+             valueContainsNull=dt.valueContainsNull,
+         )
+     elif isinstance(dt, DecimalType):
+         return DecimalType(precision=precision, scale=scale)
+     else:
+         return dt
+
+
+ def _test() -> None:
+     import doctest
+     import sys
+     import pyspark.pandas.spark.utils
+
+     globs = pyspark.pandas.spark.utils.__dict__.copy()
+     (failure_count, test_count) = doctest.testmod(
+         pyspark.pandas.spark.utils,
+         globs=globs,
+         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+     )
+     if failure_count:
+         sys.exit(-1)
+
+
+ if __name__ == "__main__":
+     _test()
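Both helpers above are pure transformations over DataType trees: as_nullable_spark_type relaxes every nullability flag, and force_decimal_precision_scale rewrites every DecimalType to a single precision/scale. A small self-contained sketch composing them (the field names are illustrative, not from the package):

    from pyspark.sql.types import ArrayType, DecimalType, StructField, StructType
    from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale

    schema = StructType([
        StructField("amount", DecimalType(10, 2), nullable=False),
        StructField("history", ArrayType(DecimalType(20, 4), containsNull=False), nullable=False),
    ])

    # Every field and element becomes nullable, then every DecimalType is
    # coerced to the requested precision/scale (the defaults are 38 and 18).
    normalized = force_decimal_precision_scale(
        as_nullable_spark_type(schema), precision=30, scale=15
    )
    print(normalized)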
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
  from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
 
 
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\x92\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
 
  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
  _globals['_TABLEARGUMENTINFO']._serialized_start=931
  _globals['_TABLEARGUMENTINFO']._serialized_end=1027
  _globals['_AGGREGATE']._serialized_start=1030
- _globals['_AGGREGATE']._serialized_end=1688
- _globals['_AGGREGATE_PIVOT']._serialized_start=1363
- _globals['_AGGREGATE_PIVOT']._serialized_end=1461
- _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1463
- _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1526
- _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1529
- _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1688
+ _globals['_AGGREGATE']._serialized_end=1741
+ _globals['_AGGREGATE_PIVOT']._serialized_start=1416
+ _globals['_AGGREGATE_PIVOT']._serialized_end=1514
+ _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
+ _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
+ _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
+ _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
  # @@protoc_insertion_point(module_scope)
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi
@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
  def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...
 
  class Aggregate(_message.Message):
- __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
+ __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
  class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
  __slots__ = ()
  GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]
@@ -108,10 +108,12 @@ class Aggregate(_message.Message):
  AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
  PIVOT_FIELD_NUMBER: _ClassVar[int]
  GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
+ HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
  input: _relations_pb2.Relation
  group_type: Aggregate.GroupType
  grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
  aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
  pivot: Aggregate.Pivot
  grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
- def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ...) -> None: ...
+ having_condition: _expressions_pb2.Expression
+ def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...
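The functional change in both generated files is the new having_condition field (tag 7) on the Aggregate extension, an Expression that carries a HAVING predicate alongside the grouping and aggregate expressions; the shifted _serialized_start/_serialized_end offsets above are just regenerated descriptor bookkeeping. A hedged sketch of how a client could populate the field; the predicate and column name are placeholders, not taken from the package:

    from pyspark.sql.connect.proto import expressions_pb2
    from snowflake.snowpark_connect.proto import snowflake_relation_ext_pb2 as ext_pb2

    # Illustrative predicate: count_col > 10
    having = expressions_pb2.Expression()
    having.unresolved_function.function_name = ">"
    having.unresolved_function.arguments.add().unresolved_attribute.unparsed_identifier = "count_col"
    having.unresolved_function.arguments.add().literal.long = 10

    agg = ext_pb2.Aggregate(
        group_type=ext_pb2.Aggregate.GROUP_TYPE_GROUPBY,
        having_condition=having,
    )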
snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py
@@ -8,7 +8,10 @@ import typing
  import pandas
  import pyspark.sql.connect.proto.common_pb2 as common_proto
  import pyspark.sql.connect.proto.types_pb2 as types_proto
- from snowflake.core.exceptions import NotFoundError
+ from pyspark.sql.connect.client.core import Retrying
+ from snowflake.core.exceptions import APIError, NotFoundError
+ from snowflake.core.schema import Schema
+ from snowflake.core.table import Table, TableColumn
 
  from snowflake.snowpark import functions
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -22,6 +25,7 @@ from snowflake.snowpark_connect.config import (
      global_config,
  )
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.exceptions import MaxRetryExceeded
  from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import (
      AbstractSparkCatalog,
      _get_current_snowflake_schema,
@@ -39,6 +43,37 @@ from snowflake.snowpark_connect.utils.telemetry import (
  from snowflake.snowpark_connect.utils.udf_cache import cached_udf
 
 
+ def _is_retryable_api_error(e: Exception) -> bool:
+     """
+     Determine if an APIError should be retried.
+
+     Only retry on server errors, rate limiting, and transient network issues.
+     Don't retry on client errors like authentication, authorization, or validation failures.
+     """
+     if not isinstance(e, APIError):
+         return False
+
+     # Check if the error has a status_code attribute
+     if hasattr(e, "status_code"):
+         # Retry on server errors (5xx), rate limiting (429), and some client errors (400)
+         # 400 can be transient in some cases (like the original error trace shows)
+         return e.status_code in [400, 429, 500, 502, 503, 504]
+
+     # For APIErrors without explicit status codes, check the message
+     error_msg = str(e).lower()
+     retryable_patterns = [
+         "timeout",
+         "connection",
+         "network",
+         "unavailable",
+         "temporary",
+         "rate limit",
+         "throttle",
+     ]
+
+     return any(pattern in error_msg for pattern in retryable_patterns)
+
+
  def _normalize_identifier(identifier: str | None) -> str | None:
      if identifier is None:
          return None
@@ -73,10 +108,25 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          )
          sp_catalog = get_or_create_snowpark_session().catalog
 
-         dbs = sp_catalog.list_schemas(
-             database=sf_quote(sf_database),
-             pattern=_normalize_identifier(sf_schema),
-         )
+         dbs: list[Schema] | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100,  # 100ms
+             max_backoff=5000,  # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 dbs = sp_catalog.list_schemas(
+                     database=sf_quote(sf_database),
+                     pattern=_normalize_identifier(sf_schema),
+                 )
+         if dbs is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch databases {f'with pattern {pattern} ' if pattern is not None else ''}after all retry attempts"
+             )
          names: list[str] = list()
          catalogs: list[str] = list()
          descriptions: list[str | None] = list()
@@ -112,9 +162,24 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          )
          sp_catalog = get_or_create_snowpark_session().catalog
 
-         db = sp_catalog.get_schema(
-             schema=sf_quote(sf_schema), database=sf_quote(sf_database)
-         )
+         db: Schema | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100,  # 100ms
+             max_backoff=5000,  # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 db = sp_catalog.get_schema(
+                     schema=sf_quote(sf_schema), database=sf_quote(sf_database)
+                 )
+         if db is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch database {spark_dbName} after all retry attempts"
+             )
 
          name = unquote_if_quoted(db.name)
          return pandas.DataFrame(
@@ -241,11 +306,27 @@ class SnowflakeCatalog(AbstractSparkCatalog):
              "Calling into another catalog is not currently supported"
          )
 
-         table = sp_catalog.get_table(
-             database=sf_quote(sf_database),
-             schema=sf_quote(sf_schema),
-             table_name=sf_quote(table_name),
-         )
+         table: Table | None = None
+         for attempt in Retrying(
+             max_retries=5,
+             initial_backoff=100,  # 100ms
+             max_backoff=5000,  # 5 s
+             backoff_multiplier=2.0,
+             jitter=100,
+             min_jitter_threshold=200,
+             can_retry=_is_retryable_api_error,
+         ):
+             with attempt:
+                 table = sp_catalog.get_table(
+                     database=sf_quote(sf_database),
+                     schema=sf_quote(sf_schema),
+                     table_name=sf_quote(table_name),
+                 )
+
+         if table is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch table {spark_tableName} after all retry attempts"
+             )
 
          return pandas.DataFrame(
              {
@@ -286,6 +367,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
      ) -> pandas.DataFrame:
          """List all columns in a table/view, optionally database name filter can be provided."""
          sp_catalog = get_or_create_snowpark_session().catalog
+         columns: list[TableColumn] | None = None
          if spark_dbName is None:
              catalog, sf_database, sf_schema, sf_table = _process_multi_layer_identifier(
                  spark_tableName
@@ -294,15 +376,39 @@ class SnowflakeCatalog(AbstractSparkCatalog):
              raise SnowparkConnectNotImplementedError(
                  "Calling into another catalog is not currently supported"
              )
-             columns = sp_catalog.list_columns(
-                 database=sf_quote(sf_database),
-                 schema=sf_quote(sf_schema),
-                 table_name=sf_quote(sf_table),
-             )
+             for attempt in Retrying(
+                 max_retries=5,
+                 initial_backoff=100,  # 100ms
+                 max_backoff=5000,  # 5 s
+                 backoff_multiplier=2.0,
+                 jitter=100,
+                 min_jitter_threshold=200,
+                 can_retry=_is_retryable_api_error,
+             ):
+                 with attempt:
+                     columns = sp_catalog.list_columns(
+                         database=sf_quote(sf_database),
+                         schema=sf_quote(sf_schema),
+                         table_name=sf_quote(sf_table),
+                     )
          else:
-             columns = sp_catalog.list_columns(
-                 schema=sf_quote(spark_dbName),
-                 table_name=sf_quote(spark_tableName),
+             for attempt in Retrying(
+                 max_retries=5,
+                 initial_backoff=100,  # 100ms
+                 max_backoff=5000,  # 5 s
+                 backoff_multiplier=2.0,
+                 jitter=100,
+                 min_jitter_threshold=200,
+                 can_retry=_is_retryable_api_error,
+             ):
+                 with attempt:
+                     columns = sp_catalog.list_columns(
+                         schema=sf_quote(spark_dbName),
+                         table_name=sf_quote(spark_tableName),
+                     )
+         if columns is None:
+             raise MaxRetryExceeded(
+                 f"Failed to fetch columns of {spark_tableName} after all retry attempts"
              )
          names: list[str] = list()
          descriptions: list[str | None] = list()
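Each of the four catalog calls above now runs inside pyspark's Retrying loop, with retryability decided by _is_retryable_api_error (status codes 400, 429, and 5xx, or transient-looking error messages). A rough sketch of the backoff schedule implied by the parameters used in the diff (initial_backoff=100 ms, backoff_multiplier=2.0, max_backoff=5000 ms, max_retries=5); the exact jitter handling inside pyspark's Retrying may differ in detail:

    import random

    def approximate_backoff_schedule(max_retries=5, initial_ms=100.0, multiplier=2.0,
                                     cap_ms=5000.0, jitter_ms=100.0):
        # Yields (attempt, approximate wait in milliseconds) for each retry.
        wait = initial_ms
        for attempt in range(1, max_retries + 1):
            yield attempt, min(wait, cap_ms) + random.uniform(0, jitter_ms)
            wait *= multiplier

    for attempt, delay_ms in approximate_backoff_schedule():
        print(f"retry {attempt}: wait ~{delay_ms:.0f} ms")
    # roughly 100, 200, 400, 800, 1600 ms plus jitter, capped at 5000 ms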