snowpark-connect 0.31.0__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/column_name_handler.py +143 -105
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/dataframe_container.py +3 -2
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/map_expression.py +5 -4
- snowflake/snowpark_connect/expression/map_extension.py +12 -6
- snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
- snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
- snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
- snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
- snowflake/snowpark_connect/relation/map_extension.py +10 -9
- snowflake/snowpark_connect/relation/map_join.py +219 -144
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +134 -16
- snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
- snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
- snowflake/snowpark_connect/relation/utils.py +46 -0
- snowflake/snowpark_connect/relation/write/map_write.py +215 -289
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +10 -26
- snowflake/snowpark_connect/type_mapping.py +38 -3
- snowflake/snowpark_connect/typed_column.py +8 -6
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +27 -4
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +1 -1
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py

@@ -0,0 +1,203 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Additional Spark functions used in pandas-on-Spark.
+"""
+from typing import Union
+
+from pyspark import SparkContext
+import pyspark.sql.functions as F
+from pyspark.sql.column import Column
+
+# For supporting Spark Connect
+from pyspark.sql.utils import is_remote
+
+
+def product(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_product",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+def stddev(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_stddev",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+def var(col: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_var",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+def skew(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_skew",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+def kurt(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_kurt",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+def mode(col: Column, dropna: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_mode",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+def covar(col1: Column, col2: Column, ddof: int) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_covar",
+            col1,  # type: ignore[arg-type]
+            col2,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+def repeat(col: Column, n: Union[int, Column]) -> Column:
+    """
+    Repeats a string column n times, and returns it as a new string column.
+    """
+    _n = F.lit(n) if isinstance(n, int) else n
+    return F.call_udf("repeat", col, _n)
+
+
+def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "ewm",
+            col,  # type: ignore[arg-type]
+            lit(alpha),
+            lit(ignore_na),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "timestampdiff",
+            lit(unit),
+            start,  # type: ignore[arg-type]
+            end,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
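Every helper in the new functions.py follows the same dispatch pattern: under Spark Connect (is_remote()) it routes through _invoke_function_over_columns, otherwise it calls into the JVM via PythonSQLUtils. A minimal usage sketch, not part of the diff, assuming pyspark 3.5+ is installed, the vendored module above is importable as pyspark.pandas.spark.functions, and a DataFrame with a numeric column "x":

# Illustrative only -- not part of the diff.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.pandas.spark import functions as SF  # module added in this release

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (4.0,)], ["x"])

# ddof mirrors pandas semantics instead of Spark's fixed stddev_samp/stddev_pop split.
df.agg(
    SF.stddev(F.col("x"), ddof=1).alias("std"),
    SF.var(F.col("x"), ddof=1).alias("var"),
    SF.skew(F.col("x")).alias("skew"),
).show()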
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py

@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Helpers and utilities to deal with PySpark instances
+"""
+from typing import overload
+
+from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+@overload
+def as_nullable_spark_type(dt: StructType) -> StructType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: MapType) -> MapType:
+    ...
+
+
+@overload
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    ...
+
+
+def as_nullable_spark_type(dt: DataType) -> DataType:
+    """
+    Returns a nullable schema or data types.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A", IntegerType(), True),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+    >>> as_nullable_spark_type(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(IntegerType(),
+    ...                 ArrayType(IntegerType(), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+        StructType([StructField('a',
+            MapType(IntegerType(),
+            ArrayType(IntegerType(), True), True), True),
+            StructField('b', StringType(), True)]), True),
+        StructField('B', FloatType(), True)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    as_nullable_spark_type(field.dataType),
+                    nullable=True,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+    elif isinstance(dt, MapType):
+        return MapType(
+            as_nullable_spark_type(dt.keyType),
+            as_nullable_spark_type(dt.valueType),
+            valueContainsNull=True,
+        )
+    else:
+        return dt
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: StructType, *, precision: int = ..., scale: int = ...
+) -> StructType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: ArrayType, *, precision: int = ..., scale: int = ...
+) -> ArrayType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: MapType, *, precision: int = ..., scale: int = ...
+) -> MapType:
+    ...
+
+
+@overload
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = ..., scale: int = ...
+) -> DataType:
+    ...
+
+
+def force_decimal_precision_scale(
+    dt: DataType, *, precision: int = 38, scale: int = 18
+) -> DataType:
+    """
+    Returns a data type with a fixed decimal type.
+
+    The precision and scale of the decimal type are fixed with the given values.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import *
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A", DecimalType(10, 0), True),
+    ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A', DecimalType(38,18), True),
+        StructField('B', DecimalType(38,18), False)])
+
+    >>> force_decimal_precision_scale(StructType([
+    ...     StructField("A",
+    ...         StructType([
+    ...             StructField('a',
+    ...                 MapType(DecimalType(5, 0),
+    ...                 ArrayType(DecimalType(20, 0), False), False), False),
+    ...             StructField('b', StringType(), True)])),
+    ...     StructField("B", DecimalType(30, 15), False)]),
+    ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+    StructType([StructField('A',
+        StructType([StructField('a',
+            MapType(DecimalType(30,15),
+            ArrayType(DecimalType(30,15), False), False), False),
+            StructField('b', StringType(), True)]), True),
+        StructField('B', DecimalType(30,15), False)])
+    """
+    if isinstance(dt, StructType):
+        new_fields = []
+        for field in dt.fields:
+            new_fields.append(
+                StructField(
+                    field.name,
+                    force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                    nullable=field.nullable,
+                    metadata=field.metadata,
+                )
+            )
+        return StructType(new_fields)
+    elif isinstance(dt, ArrayType):
+        return ArrayType(
+            force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+            containsNull=dt.containsNull,
+        )
+    elif isinstance(dt, MapType):
+        return MapType(
+            force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+            force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    elif isinstance(dt, DecimalType):
+        return DecimalType(precision=precision, scale=scale)
+    else:
+        return dt
+
+
+def _test() -> None:
+    import doctest
+    import sys
+    import pyspark.pandas.spark.utils
+
+    globs = pyspark.pandas.spark.utils.__dict__.copy()
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.pandas.spark.utils,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+    )
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
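The two helpers above rewrite Spark schemas recursively: one makes every field and element nullable, the other pins all DecimalType fields to a single precision/scale. A quick illustration, not part of the diff, assuming only that pyspark is installed and the vendored module above is importable:

# Illustrative only -- not part of the diff.
from pyspark.sql.types import ArrayType, DecimalType, IntegerType, StructField, StructType
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale

schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("amounts", ArrayType(DecimalType(10, 2), containsNull=False), nullable=False),
])

# Every field and array element becomes nullable, recursively.
print(as_nullable_spark_type(schema))

# Every DecimalType is coerced to one precision/scale (defaults: 38, 18).
print(force_decimal_precision_scale(schema, precision=38, scale=18))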
snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py

@@ -19,6 +19,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 )
 from snowflake.snowpark.functions import lit
 from snowflake.snowpark.types import BooleanType, StringType
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
     global_config,
@@ -743,7 +744,9 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         sp_schema = proto_to_snowpark_type(schema)
         columns = [c.name for c in schema.struct.fields]
         table_name_parts = split_fully_qualified_spark_name(tableName)
-        qualifiers = [
+        qualifiers: list[set[ColumnQualifier]] = [
+            {ColumnQualifier(tuple(table_name_parts))} for _ in columns
+        ]
         column_types = [f.datatype for f in sp_schema.fields]
         return DataFrameContainer.create_with_column_mapping(
             dataframe=session.createDataFrame([], sp_schema),
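The common thread in this and the remaining hunks is the new ColumnQualifier type (added in snowflake/snowpark_connect/column_qualifier.py, which is not expanded in this diff): each tracked column now carries a set[ColumnQualifier] instead of an untyped list. Purely as a reading aid, a hypothetical minimal shape consistent with the call sites above; the real class may differ:

# Hypothetical sketch -- column_qualifier.py is not shown in this diff, so this only
# mirrors what the call sites imply: hashable, built from a tuple of name parts,
# stored one set per column.
from dataclasses import dataclass


@dataclass(frozen=True)
class ColumnQualifier:
    parts: tuple[str, ...]  # e.g. ("db", "schema", "table")


table_name_parts = ["db", "schema", "orders"]  # stand-in for split_fully_qualified_spark_name(...)
columns = ["id", "amount"]
qualifiers: list[set[ColumnQualifier]] = [
    {ColumnQualifier(tuple(table_name_parts))} for _ in columns
]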
snowflake/snowpark_connect/relation/map_aggregate.py

@@ -16,6 +16,7 @@ from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
@@ -200,7 +201,7 @@ def map_pivot_aggregate(
         dataframe=result.select(*column_selectors),
         spark_column_names=reordered_spark_names,
         snowpark_column_names=reordered_snowpark_names,
-        column_qualifiers=[
+        column_qualifiers=[set() for _ in reordered_spark_names],
         parent_column_name_map=input_container.column_map,
         snowpark_column_types=reordered_types,
     )
@@ -349,7 +350,7 @@ class _ColumnMetadata:
     spark_name: str
     snowpark_name: str
     data_type: DataType
-    qualifiers:
+    qualifiers: set[ColumnQualifier]


 @dataclass(frozen=True)
@@ -385,7 +386,7 @@ class _Columns:
             col.spark_name for col in self.grouping_columns + self.aggregation_columns
         ]

-    def get_qualifiers(self) -> list[
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         return [
             col.qualifiers for col in self.grouping_columns + self.aggregation_columns
         ]
@@ -429,7 +430,7 @@ def map_aggregate_helper(
                 new_name,
                 None if skip_alias else alias,
                 None if pivot else snowpark_column.typ,
-                snowpark_column.get_qualifiers(),
+                qualifiers=snowpark_column.get_qualifiers(),
             )
         )

@@ -469,7 +470,7 @@ def map_aggregate_helper(
                 new_name,
                 None if skip_alias else alias,
                 agg_col_typ,
-
+                qualifiers=set(),
             )
         )

snowflake/snowpark_connect/relation/map_column_ops.py

@@ -29,6 +29,7 @@ from snowflake.snowpark.column import Column
 from snowflake.snowpark.table_function import _ExplodeFunctionCall
 from snowflake.snowpark.types import DataType, StructField, StructType, _NumericType
 from snowflake.snowpark_connect.column_name_handler import (
+    ColumnQualifier,
     make_column_names_snowpark_compatible,
 )
 from snowflake.snowpark_connect.config import global_config
@@ -315,6 +316,11 @@ def map_project(
     final_snowpark_columns = make_column_names_snowpark_compatible(
         new_spark_columns, rel.common.plan_id
     )
+    # if there are duplicate snowpark column names, we need to disambiguate them by their index
+    if len(new_spark_columns) != len(set(new_spark_columns)):
+        result = result.select(
+            [f"${i}" for i in range(1, len(new_spark_columns) + 1)]
+        )
     result = result.toDF(*final_snowpark_columns)
     new_snowpark_columns = final_snowpark_columns

@@ -1014,7 +1020,7 @@ def map_unpivot(
     column_project = []
     column_reverse_project = []
     snowpark_columns = []
-    qualifiers = []
+    qualifiers: list[set[ColumnQualifier]] = []
     for c in input_container.column_map.get_snowpark_columns():
         c_name = snowpark_functions_col(c, input_container.column_map).get_name()
         if c_name in unpivot_col_names:
@@ -1042,7 +1048,7 @@ def map_unpivot(
             )
             snowpark_columns.append(c)
             qualifiers.append(
-                input_container.column_map.
+                input_container.column_map.get_qualifiers_for_spark_column(c)
             )

     # Without the case when postprocessing, the result Spark dataframe is:
@@ -1087,7 +1093,7 @@ def map_unpivot(
         snowpark_functions_col(snowpark_value_column_name, input_container.column_map)
     )
     snowpark_columns.append(snowpark_value_column_name)
-    qualifiers.extend([
+    qualifiers.extend([set() for _ in range(2)])

     result = (
         input_df.select(*column_project)
snowflake/snowpark_connect/relation/map_extension.py

@@ -15,6 +15,7 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     make_column_names_snowpark_compatible,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import get_boolean_session_config_param
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -178,7 +179,7 @@ def get_udtf_project(relation: relation_proto.Relation) -> bool:

 def handle_udtf_with_table_arguments(
     udtf_info: snowflake_proto.UDTFWithTableArguments,
-) ->
+) -> DataFrameContainer:
     """
     Handle UDTF with one or more table arguments using Snowpark's join_table_function.
     For multiple table arguments, this creates a Cartesian product of all input tables.
@@ -286,7 +287,7 @@ def handle_lateral_join_with_udtf(
     left_result: DataFrameContainer,
     udtf_relation: relation_proto.Relation,
     udtf_info: tuple[snowpark.udtf.UserDefinedTableFunction, list],
-) ->
+) -> DataFrameContainer:
     """
     Handle lateral join with UDTF on the right side using join_table_function.
     """
@@ -319,7 +320,7 @@ def handle_lateral_join_with_udtf(

 def map_aggregate(
     aggregate: snowflake_proto.Aggregate, plan_id: int
-) ->
+) -> DataFrameContainer:
     input_container = map_relation(aggregate.input)
     input_df: snowpark.DataFrame = input_container.dataframe

@@ -363,7 +364,7 @@ def map_aggregate(
         return new_names[0], snowpark_column

     raw_groupings: list[tuple[str, TypedColumn]] = []
-    raw_aggregations: list[tuple[str, TypedColumn,
+    raw_aggregations: list[tuple[str, TypedColumn, set[ColumnQualifier]]] = []

     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -401,11 +402,11 @@ def map_aggregate(
             col = _map_column(exp)
             if exp.WhichOneof("expr_type") == "unresolved_attribute":
                 spark_name = col[0]
-                qualifiers
-
-                )
+                qualifiers: set[
+                    ColumnQualifier
+                ] = input_container.column_map.get_qualifiers_for_spark_column(spark_name)
             else:
-                qualifiers =
+                qualifiers = set()

             raw_aggregations.append((col[0], col[1], qualifiers))

@@ -438,7 +439,7 @@ def map_aggregate(
     spark_columns: list[str] = []
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
-    all_qualifiers: list[
+    all_qualifiers: list[set[ColumnQualifier]] = []

     # Use grouping columns directly without aliases
     groupings = [col.col for _, col in raw_groupings]