snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +0 -11
- snowflake/snowpark_connect/error/error_utils.py +7 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/function_defaults.py +207 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +14 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +18 -2
- snowflake/snowpark_connect/expression/map_extension.py +12 -2
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
- snowflake/snowpark_connect/relation/map_extension.py +65 -31
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
- snowflake/snowpark_connect/relation/map_sql.py +22 -5
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
- snowflake/snowpark_connect/relation/write/map_write.py +243 -68
- snowflake/snowpark_connect/server.py +25 -5
- snowflake/snowpark_connect/type_mapping.py +2 -2
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/spark_decoder.py +1 -1
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py
@@ -0,0 +1,203 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Additional Spark functions used in pandas-on-Spark.
+"""
+from typing import Union
+
+from pyspark import SparkContext
+import pyspark.sql.functions as F
+from pyspark.sql.column import Column
+
+# For supporting Spark Connect
+from pyspark.sql.utils import is_remote
+
+
+def product(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_product",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+def stddev(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_stddev",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+def var(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_var",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+def skew(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_skew",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+def kurt(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_kurt",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+def mode(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_mode",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+def covar(col1: Column, col2: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_covar",
+            col1,  # type: ignore[arg-type]
+            col2,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+def repeat(col: Column, n: Union[int, Column]) -> Column:
+    """
+    Repeats a string column n times, and returns it as a new string column.
+    """
+    _n = F.lit(n) if isinstance(n, int) else n
+    return F.call_udf("repeat", col, _n)
+
+
+def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "ewm",
+            col,  # type: ignore[arg-type]
+            lit(alpha),
+            lit(ignore_na),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "timestampdiff",
+            lit(unit),
+            start,  # type: ignore[arg-type]
+            end,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
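Each helper above returns a `Column` and dispatches on `is_remote()`: under Spark Connect it resolves through `_invoke_function_over_columns` with a registered function name such as `"pandas_stddev"`, otherwise it calls the JVM `PythonSQLUtils` entry point. A minimal usage sketch, assuming an active Spark session (the DataFrame and column names here are illustrative, not taken from the diff):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF  # the module shown above

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (4.0,)], ["x"])

# stddev()/var() take a pandas-style ddof; the wrapper picks the Connect or
# JVM code path automatically based on is_remote().
df.agg(
    SF.stddev(df["x"], ddof=1).alias("stddev_sample"),
    SF.var(df["x"], ddof=0).alias("var_population"),
).show()
```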
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Helpers and utilities to deal with PySpark instances
+"""
+from typing import overload
+
+from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+@overload
+def as_nullable_spark_type(dt: StructType) -> StructType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: MapType) -> MapType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    ...
+
+
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    """
+    Returns a nullable schema or data types.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A", IntegerType(), True),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(IntegerType(),
+    ...                 ArrayType(IntegerType(), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+    StructType([StructField('a',
+    MapType(IntegerType(),
+    ArrayType(IntegerType(), True), True), True),
+    StructField('b', StringType(), True)]), True),
+    StructField('B', FloatType(), True)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    as_nullable_spark_type(field.dataType),
+                    nullable=True,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+    elif isinstance(dt, MapType):
+        return MapType(
+            as_nullable_spark_type(dt.keyType),
+            as_nullable_spark_type(dt.valueType),
+            valueContainsNull=True,
+        )
+    else:
+        return dt
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: StructType, *, precision: int = ..., scale: int = ...
+) -> StructType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: ArrayType, *, precision: int = ..., scale: int = ...
+) -> ArrayType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: MapType, *, precision: int = ..., scale: int = ...
+) -> MapType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = ..., scale: int = ...
+) -> DataType:
+    ...
+
+
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = 38, scale: int = 18
+) -> DataType:
+    """
+    Returns a data type with a fixed decimal type.
+
+    The precision and scale of the decimal type are fixed with the given values.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A", DecimalType(10, 0), True),
+    ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', DecimalType(38,18), True),
+    StructField('B', DecimalType(38,18), False)])
+
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(DecimalType(5, 0),
+    ...                 ArrayType(DecimalType(20, 0), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", DecimalType(30, 15), False)]),
+    ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+    StructType([StructField('a',
+    MapType(DecimalType(30,15),
+    ArrayType(DecimalType(30,15), False), False), False),
+    StructField('b', StringType(), True)]), True),
+    StructField('B', DecimalType(30,15), False)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                    nullable=field.nullable,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(
+            force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+            containsNull=dt.containsNull,
+        )
+    elif isinstance(dt, MapType):
+        return MapType(
+            force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+            force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    elif isinstance(dt, DecimalType):
+        return DecimalType(precision=precision, scale=scale)
+    else:
+        return dt
+
+
+def _test() -> None:
+    import doctest
+    import sys
+    import pyspark.pandas.spark.utils
+
+    globs = pyspark.pandas.spark.utils.__dict__.copy()
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.pandas.spark.utils,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+    )
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
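The two schema helpers are typically composed: relax nullability first, then pin every `DecimalType` to one precision/scale so that schemas produced by different code paths compare equal. A small sketch using only the functions defined above (the example schema is illustrative):

```python
from pyspark.sql.types import DecimalType, LongType, StructField, StructType
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale

schema = StructType([
    StructField("amount", DecimalType(10, 2), nullable=False),
    StructField("count", LongType(), nullable=False),
])

# Every field becomes nullable and every decimal becomes DecimalType(38, 18),
# the defaults used by force_decimal_precision_scale.
normalized = force_decimal_precision_scale(as_nullable_spark_type(schema))
print(normalized)
```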
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_TABLEARGUMENTINFO']._serialized_start=931
   _globals['_TABLEARGUMENTINFO']._serialized_end=1027
   _globals['_AGGREGATE']._serialized_start=1030
-  _globals['_AGGREGATE']._serialized_end=
-  _globals['_AGGREGATE_PIVOT']._serialized_start=
-  _globals['_AGGREGATE_PIVOT']._serialized_end=
-  _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=
-  _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=
-  _globals['_AGGREGATE_GROUPTYPE']._serialized_start=
-  _globals['_AGGREGATE_GROUPTYPE']._serialized_end=
+  _globals['_AGGREGATE']._serialized_end=1741
+  _globals['_AGGREGATE_PIVOT']._serialized_start=1416
+  _globals['_AGGREGATE_PIVOT']._serialized_end=1514
+  _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
+  _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
+  _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
+  _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
 # @@protoc_insertion_point(module_scope)
snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi
@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
     def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...
 
 class Aggregate(_message.Message):
-    __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
+    __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
     class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
         __slots__ = ()
         GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]
@@ -108,10 +108,12 @@ class Aggregate(_message.Message):
     AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
     PIVOT_FIELD_NUMBER: _ClassVar[int]
     GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
+    HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
     input: _relations_pb2.Relation
     group_type: Aggregate.GroupType
     grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     pivot: Aggregate.Pivot
     grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
-
+    having_condition: _expressions_pb2.Expression
+    def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...
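The net effect of the two stub changes above is a new optional `having_condition` field on the `snowflake.ext.Aggregate` extension message. A minimal construction sketch, assuming the generated modules are importable; the comparison expression built here is illustrative only and would normally carry arguments:

```python
from pyspark.sql.connect.proto import expressions_pb2 as expr_pb2
from snowflake.snowpark_connect.proto import snowflake_relation_ext_pb2 as ext_pb2

# Placeholder HAVING predicate; a real client would also populate the
# function's argument expressions.
having = expr_pb2.Expression()
having.unresolved_function.function_name = ">"

agg = ext_pb2.Aggregate(
    group_type=ext_pb2.Aggregate.GROUP_TYPE_GROUPBY,
    having_condition=having,
)
print(agg.HasField("having_condition"))  # True once the field is set
```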
snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py
@@ -8,7 +8,10 @@ import typing
 import pandas
 import pyspark.sql.connect.proto.common_pb2 as common_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
-from
+from pyspark.sql.connect.client.core import Retrying
+from snowflake.core.exceptions import APIError, NotFoundError
+from snowflake.core.schema import Schema
+from snowflake.core.table import Table, TableColumn
 
 from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -22,6 +25,7 @@ from snowflake.snowpark_connect.config import (
     global_config,
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.error.exceptions import MaxRetryExceeded
 from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import (
     AbstractSparkCatalog,
     _get_current_snowflake_schema,
@@ -39,6 +43,37 @@ from snowflake.snowpark_connect.utils.telemetry import (
 from snowflake.snowpark_connect.utils.udf_cache import cached_udf
 
 
+def _is_retryable_api_error(e: Exception) -> bool:
+    """
+    Determine if an APIError should be retried.
+
+    Only retry on server errors, rate limiting, and transient network issues.
+    Don't retry on client errors like authentication, authorization, or validation failures.
+    """
+    if not isinstance(e, APIError):
+        return False
+
+    # Check if the error has a status_code attribute
+    if hasattr(e, "status_code"):
+        # Retry on server errors (5xx), rate limiting (429), and some client errors (400)
+        # 400 can be transient in some cases (like the original error trace shows)
+        return e.status_code in [400, 429, 500, 502, 503, 504]
+
+    # For APIErrors without explicit status codes, check the message
+    error_msg = str(e).lower()
+    retryable_patterns = [
+        "timeout",
+        "connection",
+        "network",
+        "unavailable",
+        "temporary",
+        "rate limit",
+        "throttle",
+    ]
+
+    return any(pattern in error_msg for pattern in retryable_patterns)
+
+
 def _normalize_identifier(identifier: str | None) -> str | None:
     if identifier is None:
         return None
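This predicate is paired with the `Retrying` helper imported from `pyspark.sql.connect.client.core` in every catalog call changed below: the loop yields attempt context managers, backs off between failures, and the caller raises `MaxRetryExceeded` if no attempt produced a result. A condensed sketch of the pattern as it appears in the hunks that follow; `some_catalog_call` is a placeholder for calls such as `sp_catalog.list_schemas(...)`:

```python
from pyspark.sql.connect.client.core import Retrying

result = None
for attempt in Retrying(
    max_retries=5,
    initial_backoff=100,  # 100 ms
    max_backoff=5000,     # 5 s
    backoff_multiplier=2.0,
    jitter=100,
    min_jitter_threshold=200,
    can_retry=_is_retryable_api_error,  # defined in the hunk above
):
    with attempt:
        result = some_catalog_call()  # placeholder
if result is None:
    raise MaxRetryExceeded("Failed after all retry attempts")
```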
@@ -73,10 +108,25 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         )
         sp_catalog = get_or_create_snowpark_session().catalog
 
-        dbs =
-
-
-
+        dbs: list[Schema] | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100,  # 100ms
+            max_backoff=5000,  # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                dbs = sp_catalog.list_schemas(
+                    database=sf_quote(sf_database),
+                    pattern=_normalize_identifier(sf_schema),
+                )
+        if dbs is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch databases {f'with pattern {pattern} ' if pattern is not None else ''}after all retry attempts"
+            )
         names: list[str] = list()
         catalogs: list[str] = list()
         descriptions: list[str | None] = list()
@@ -112,9 +162,24 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         )
         sp_catalog = get_or_create_snowpark_session().catalog
 
-        db =
-
-
+        db: Schema | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100,  # 100ms
+            max_backoff=5000,  # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                db = sp_catalog.get_schema(
+                    schema=sf_quote(sf_schema), database=sf_quote(sf_database)
+                )
+        if db is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch database {spark_dbName} after all retry attempts"
+            )
 
         name = unquote_if_quoted(db.name)
         return pandas.DataFrame(
@@ -241,11 +306,27 @@ class SnowflakeCatalog(AbstractSparkCatalog):
                 "Calling into another catalog is not currently supported"
             )
 
-        table =
-
-
-
-
+        table: Table | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100,  # 100ms
+            max_backoff=5000,  # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                table = sp_catalog.get_table(
+                    database=sf_quote(sf_database),
+                    schema=sf_quote(sf_schema),
+                    table_name=sf_quote(table_name),
+                )
+
+        if table is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch table {spark_tableName} after all retry attempts"
+            )
 
         return pandas.DataFrame(
             {
@@ -286,6 +367,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
     ) -> pandas.DataFrame:
         """List all columns in a table/view, optionally database name filter can be provided."""
         sp_catalog = get_or_create_snowpark_session().catalog
+        columns: list[TableColumn] | None = None
     if spark_dbName is None:
         catalog, sf_database, sf_schema, sf_table = _process_multi_layer_identifier(
             spark_tableName
@@ -294,15 +376,39 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             raise SnowparkConnectNotImplementedError(
                 "Calling into another catalog is not currently supported"
             )
-
-
-
-
-
+            for attempt in Retrying(
+                max_retries=5,
+                initial_backoff=100,  # 100ms
+                max_backoff=5000,  # 5 s
+                backoff_multiplier=2.0,
+                jitter=100,
+                min_jitter_threshold=200,
+                can_retry=_is_retryable_api_error,
+            ):
+                with attempt:
+                    columns = sp_catalog.list_columns(
+                        database=sf_quote(sf_database),
+                        schema=sf_quote(sf_schema),
+                        table_name=sf_quote(sf_table),
+                    )
         else:
-
-
-
+            for attempt in Retrying(
+                max_retries=5,
+                initial_backoff=100,  # 100ms
+                max_backoff=5000,  # 5 s
+                backoff_multiplier=2.0,
+                jitter=100,
+                min_jitter_threshold=200,
+                can_retry=_is_retryable_api_error,
+            ):
+                with attempt:
+                    columns = sp_catalog.list_columns(
+                        schema=sf_quote(spark_dbName),
+                        table_name=sf_quote(spark_tableName),
+                    )
+        if columns is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch columns of {spark_tableName} after all retry attempts"
             )
         names: list[str] = list()
         descriptions: list[str | None] = list()