snowpark-connect 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff shows the changes between these publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Note: this release of snowpark-connect has been flagged as potentially problematic.
- snowflake/snowpark_connect/expression/function_defaults.py +207 -0
- snowflake/snowpark_connect/expression/literal.py +5 -0
- snowflake/snowpark_connect/expression/map_expression.py +10 -1
- snowflake/snowpark_connect/expression/map_extension.py +12 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +11 -12
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/map_column_ops.py +1 -32
- snowflake/snowpark_connect/relation/map_extension.py +7 -7
- snowflake/snowpark_connect/relation/map_row_ops.py +2 -29
- snowflake/snowpark_connect/relation/read/utils.py +6 -7
- snowflake/snowpark_connect/relation/utils.py +1 -170
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +1 -1
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +25 -20
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py
@@ -0,0 +1,203 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Additional Spark functions used in pandas-on-Spark.
+"""
+from typing import Union
+
+from pyspark import SparkContext
+import pyspark.sql.functions as F
+from pyspark.sql.column import Column
+
+# For supporting Spark Connect
+from pyspark.sql.utils import is_remote
+
+
+def product(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_product",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+def stddev(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_stddev",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+def var(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_var",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+def skew(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_skew",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+def kurt(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_kurt",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+def mode(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_mode",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+def covar(col1: Column, col2: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_covar",
+            col1,  # type: ignore[arg-type]
+            col2,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+def repeat(col: Column, n: Union[int, Column]) -> Column:
+    """
+    Repeats a string column n times, and returns it as a new string column.
+    """
+    _n = F.lit(n) if isinstance(n, int) else n
+    return F.call_udf("repeat", col, _n)
+
+
+def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "ewm",
+            col,  # type: ignore[arg-type]
+            lit(alpha),
+            lit(ignore_na),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "timestampdiff",
+            lit(unit),
+            start,  # type: ignore[arg-type]
+            end,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
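For context, a rough usage sketch (not part of the diff): these vendored helpers mirror upstream pyspark.pandas.spark.functions, which backs pandas-on-Spark reductions such as Series.prod(), Series.std() and Series.skew(). Assuming a working Spark (or Spark Connect) session:

import pyspark.pandas as ps

psser = ps.Series([1.0, 2.0, 3.0, None])
print(psser.prod())       # backed by product(col, dropna) above
print(psser.std(ddof=1))  # backed by stddev(col, ddof) above
print(psser.skew())       # backed by skew(col) above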
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Helpers and utilities to deal with PySpark instances
+"""
+from typing import overload
+
+from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+@overload
+def as_nullable_spark_type(dt: StructType) -> StructType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: MapType) -> MapType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    ...
+
+
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    """
+    Returns a nullable schema or data types.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A", IntegerType(), True),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(IntegerType(),
+    ...                     ArrayType(IntegerType(), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+        StructType([StructField('a',
+            MapType(IntegerType(),
+                ArrayType(IntegerType(), True), True), True),
+            StructField('b', StringType(), True)]), True),
+        StructField('B', FloatType(), True)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    as_nullable_spark_type(field.dataType),
+                    nullable=True,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+    elif isinstance(dt, MapType):
+        return MapType(
+            as_nullable_spark_type(dt.keyType),
+            as_nullable_spark_type(dt.valueType),
+            valueContainsNull=True,
+        )
+    else:
+        return dt
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: StructType, *, precision: int = ..., scale: int = ...
+) -> StructType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: ArrayType, *, precision: int = ..., scale: int = ...
+) -> ArrayType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: MapType, *, precision: int = ..., scale: int = ...
+) -> MapType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = ..., scale: int = ...
+) -> DataType:
+    ...
+
+
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = 38, scale: int = 18
+) -> DataType:
+    """
+    Returns a data type with a fixed decimal type.
+
+    The precision and scale of the decimal type are fixed with the given values.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A", DecimalType(10, 0), True),
+    ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', DecimalType(38,18), True),
+        StructField('B', DecimalType(38,18), False)])
+
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(DecimalType(5, 0),
+    ...                     ArrayType(DecimalType(20, 0), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", DecimalType(30, 15), False)]),
+    ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+        StructType([StructField('a',
+            MapType(DecimalType(30,15),
+                ArrayType(DecimalType(30,15), False), False), False),
+            StructField('b', StringType(), True)]), True),
+        StructField('B', DecimalType(30,15), False)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                    nullable=field.nullable,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(
+            force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+            containsNull=dt.containsNull,
+        )
+    elif isinstance(dt, MapType):
+        return MapType(
+            force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+            force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    elif isinstance(dt, DecimalType):
+        return DecimalType(precision=precision, scale=scale)
+    else:
+        return dt
+
+
+def _test() -> None:
+    import doctest
+    import sys
+    import pyspark.pandas.spark.utils
+
+    globs = pyspark.pandas.spark.utils.__dict__.copy()
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.pandas.spark.utils,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+    )
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
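For context, a small usage sketch (not part of the diff) based on the doctests above; the import path follows the vendored module layout:

from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.sql.types import ArrayType, DecimalType, IntegerType, StructField, StructType

schema = StructType([
    StructField("a", DecimalType(10, 2), nullable=False),
    StructField("b", ArrayType(IntegerType(), containsNull=False), nullable=False),
])

# Every field and nested element/value becomes nullable.
print(as_nullable_spark_type(schema))

# Every DecimalType is coerced to the given precision/scale (defaults: 38, 18).
print(force_decimal_precision_scale(schema, precision=30, scale=15))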
snowflake/snowpark_connect/relation/map_column_ops.py
@@ -6,12 +6,10 @@ import ast
 import json
 import sys
 from collections import defaultdict
-from copy import copy
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
-from pyspark.errors import PySparkValueError
 from pyspark.errors.exceptions.base import AnalysisException
 from pyspark.serializers import CloudPickleSerializer
 
@@ -46,7 +44,6 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.utils import (
     TYPE_MAP_FOR_TO_SCHEMA,
-    can_sort_be_flattened,
     snowpark_functions_col,
 )
 from snowflake.snowpark_connect.type_mapping import (
@@ -346,12 +343,6 @@ def map_sort(
 
     sort_order = sort.order
 
-    if not sort_order:
-        raise PySparkValueError(
-            error_class="CANNOT_BE_EMPTY",
-            message="At least one column must be specified.",
-        )
-
     if len(sort_order) == 1:
         parsed_col_name = split_fully_qualified_spark_name(
             sort_order[0].child.unresolved_attribute.unparsed_identifier
@@ -433,29 +424,7 @@ def map_sort(
     if not order_specified:
         ascending = None
 
-
-    sort_expressions = [c._expression for c in cols]
-    if (
-        can_sort_be_flattened(select_statement, *sort_expressions)
-        and input_df._ops_after_agg is None
-    ):
-        # "flattened" order by that will allow using dropped columns
-        new = copy(select_statement)
-        new.from_ = select_statement.from_.to_subqueryable()
-        new.pre_actions = new.from_.pre_actions
-        new.post_actions = new.from_.post_actions
-        new.order_by = sort_expressions + (select_statement.order_by or [])
-        new.column_states = select_statement.column_states
-        new._merge_projection_complexity_with_subquery = False
-        new.df_ast_ids = (
-            select_statement.df_ast_ids.copy()
-            if select_statement.df_ast_ids is not None
-            else None
-        )
-        new.attributes = select_statement.attributes
-        result = input_df._with_plan(new)
-    else:
-        result = input_df.sort(cols, ascending=ascending)
+    result = input_df.sort(cols, ascending=ascending)
 
     return DataFrameContainer(
         result,
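Net effect of the hunks above: map_sort no longer rewrites the underlying SELECT for "flattenable" sorts and no longer raises PySparkValueError for an empty sort order here; every orderBy now goes through Snowpark's DataFrame.sort. An illustrative client-side sketch (assuming a Spark Connect session served by snowpark-connect; data is made up):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "name"])

# Both forms now follow the single `result = input_df.sort(cols, ascending=ascending)`
# path in map_sort; `ascending` stays None when no explicit direction is given.
df.orderBy("id").show()
df.orderBy(df.id.desc()).show()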
snowflake/snowpark_connect/relation/map_extension.py
@@ -347,6 +347,13 @@ def map_aggregate(
     raw_groupings: list[tuple[str, TypedColumn]] = []
     raw_aggregations: list[tuple[str, TypedColumn]] = []
 
+    if not is_group_by_all:
+        raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
+
+        # Set the current grouping columns in context for grouping_id() function
+        grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
+        set_current_grouping_columns(grouping_spark_columns)
+
     agg_count = get_sql_aggregate_function_count()
     for exp in aggregate.aggregate_expressions:
         col = _map_column(exp)
@@ -359,13 +366,6 @@ def map_aggregate(
         else:
             agg_count = new_agg_count
 
-    if not is_group_by_all:
-        raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
-
-        # Set the current grouping columns in context for grouping_id() function
-        grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
-        set_current_grouping_columns(grouping_spark_columns)
-
     # Now create column name lists and assign aliases.
     # In case of GROUP BY ALL, even though groupings are a subset of aggregations,
     # they will have their own aliases so we can drop them later.
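The change above is a reordering: grouping expressions are now mapped, and registered via set_current_grouping_columns, before the aggregate expressions are mapped, so a grouping_id() in the aggregate list can resolve against the current grouping columns. An illustrative PySpark query of that shape (made-up data):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("eng", "dev", 100), ("eng", "mgr", 150), ("hr", "mgr", 120)],
    ["dept", "role", "salary"],
)

# grouping_id() in the aggregate list depends on the grouping columns
# ("dept", "role") being known when the aggregate expressions are mapped.
df.cube("dept", "role").agg(
    F.grouping_id().alias("gid"),
    F.sum("salary").alias("total"),
).show()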
snowflake/snowpark_connect/relation/map_row_ops.py
@@ -1,7 +1,7 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
@@ -9,7 +9,6 @@ from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentExc
 
 import snowflake.snowpark_connect.relation.utils as utils
 from snowflake import snowpark
-from snowflake.snowpark._internal.analyzer.binary_expression import And
 from snowflake.snowpark.functions import col, expr as snowpark_expr
 from snowflake.snowpark.types import (
     BooleanType,
@@ -31,7 +30,6 @@ from snowflake.snowpark_connect.expression.map_expression import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
-from snowflake.snowpark_connect.relation.utils import can_filter_be_flattened
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -555,32 +553,7 @@ def map_filter(
         rel.filter.condition, input_container.column_map, typer
     )
 
-
-    condition_exp = condition.col._expression
-    if (
-        can_filter_be_flattened(select_statement, condition_exp)
-        and input_df._ops_after_agg is None
-    ):
-        new = copy(select_statement)
-        new.from_ = select_statement.from_.to_subqueryable()
-        new.pre_actions = new.from_.pre_actions
-        new.post_actions = new.from_.post_actions
-        new.column_states = select_statement.column_states
-        new.where = (
-            And(select_statement.where, condition_exp)
-            if select_statement.where is not None
-            else condition_exp
-        )
-        new._merge_projection_complexity_with_subquery = False
-        new.df_ast_ids = (
-            select_statement.df_ast_ids.copy()
-            if select_statement.df_ast_ids is not None
-            else None
-        )
-        new.attributes = select_statement.attributes
-        result = input_df._with_plan(new)
-    else:
-        result = input_df.filter(condition.col)
+    result = input_df.filter(condition.col)
 
     return DataFrameContainer(
         result,
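Mirroring the map_sort change, map_filter drops the SELECT-flattening rewrite (including merging the condition into an existing WHERE via And) and always delegates to input_df.filter(condition.col). An illustrative client-side sketch (made-up data):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

# Both spellings are mapped through the same `input_df.filter(condition.col)` call.
df.filter(df.id > 1).show()
df.where("name = 'b'").show()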
snowflake/snowpark_connect/relation/read/utils.py
@@ -73,13 +73,12 @@ def rename_columns_as_snowflake_standard(
         return df, []
 
     new_columns = make_column_names_snowpark_compatible(df.columns, plan_id)
-
-
-
-
-
-
-    return (result, new_columns)
+    return (
+        df.select(
+            *(df.col(orig).alias(alias) for orig, alias in zip(df.columns, new_columns))
+        ),
+        new_columns,
+    )
 
 
 class Connection(Protocol):