snowpark-connect 0.31.0-py3-none-any.whl → 0.33.0-py3-none-any.whl

This diff compares the contents of publicly available package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Potentially problematic release: this version of snowpark-connect has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (111)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +143 -105
  3. snowflake/snowpark_connect/column_qualifier.py +43 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  8. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  9. snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  16. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  17. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  18. snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
  19. snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
  20. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  21. snowflake/snowpark_connect/relation/map_join.py +219 -144
  22. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  23. snowflake/snowpark_connect/relation/map_sql.py +134 -16
  24. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  25. snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
  26. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  27. snowflake/snowpark_connect/relation/utils.py +46 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +215 -289
  29. snowflake/snowpark_connect/resources_initializer.py +25 -13
  30. snowflake/snowpark_connect/server.py +10 -26
  31. snowflake/snowpark_connect/type_mapping.py +38 -3
  32. snowflake/snowpark_connect/typed_column.py +8 -6
  33. snowflake/snowpark_connect/utils/sequence.py +21 -0
  34. snowflake/snowpark_connect/utils/session.py +27 -4
  35. snowflake/snowpark_connect/version.py +1 -1
  36. snowflake/snowpark_decoder/dp_session.py +1 -1
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
  39. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  99. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  100. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  101. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  102. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  103. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  104. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
  105. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
  106. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
  107. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
  108. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
  109. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
  110. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
  111. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,203 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Additional Spark functions used in pandas-on-Spark.
+ """
+ from typing import Union
+
+ from pyspark import SparkContext
+ import pyspark.sql.functions as F
+ from pyspark.sql.column import Column
+
+ # For supporting Spark Connect
+ from pyspark.sql.utils import is_remote
+
+
+ def product(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_product",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+ def stddev(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_stddev",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+ def var(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_var",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+ def skew(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_skew",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+ def kurt(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_kurt",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+ def mode(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_mode",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+ def covar(col1: Column, col2: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_covar",
+             col1,  # type: ignore[arg-type]
+             col2,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+ def repeat(col: Column, n: Union[int, Column]) -> Column:
+     """
+     Repeats a string column n times, and returns it as a new string column.
+     """
+     _n = F.lit(n) if isinstance(n, int) else n
+     return F.call_udf("repeat", col, _n)
+
+
+ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "ewm",
+             col,  # type: ignore[arg-type]
+             lit(alpha),
+             lit(ignore_na),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+ def last_non_null(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "last_non_null",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+ def null_index(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "null_index",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+ def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "timestampdiff",
+             lit(unit),
+             start,  # type: ignore[arg-type]
+             end,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
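
Every helper in the file above follows the same dispatch pattern: under Spark Connect (`is_remote()`) it builds an unresolved call to a named internal function such as `pandas_stddev`, otherwise it goes through the JVM gateway via `PythonSQLUtils`. A minimal usage sketch, assuming a plain PySpark 3.5 session; pandas-on-Spark normally calls these helpers internally rather than user code:

```python
# Illustrative only: each helper returns a pyspark Column, so it can be used
# like any other aggregate expression.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.pandas.spark import functions as SF  # the module added above

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (4.0,)], ["x"])

# stddev(col, ddof) resolves to "pandas_stddev" on Spark Connect, or to
# PythonSQLUtils.pandasStddev on a classic JVM-backed session.
df.select(SF.stddev(F.col("x"), 1).alias("sd")).show()
```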
@@ -0,0 +1,202 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Helpers and utilities to deal with PySpark instances
+ """
+ from typing import overload
+
+ from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+ @overload
+ def as_nullable_spark_type(dt: StructType) -> StructType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: MapType) -> MapType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     ...
+
+
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     """
+     Returns a nullable schema or data types.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A", IntegerType(), True),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(IntegerType(),
+     ...                     ArrayType(IntegerType(), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(IntegerType(),
+                 ArrayType(IntegerType(), True), True), True),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', FloatType(), True)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     as_nullable_spark_type(field.dataType),
+                     nullable=True,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+     elif isinstance(dt, MapType):
+         return MapType(
+             as_nullable_spark_type(dt.keyType),
+             as_nullable_spark_type(dt.valueType),
+             valueContainsNull=True,
+         )
+     else:
+         return dt
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: StructType, *, precision: int = ..., scale: int = ...
+ ) -> StructType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: ArrayType, *, precision: int = ..., scale: int = ...
+ ) -> ArrayType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: MapType, *, precision: int = ..., scale: int = ...
+ ) -> MapType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = ..., scale: int = ...
+ ) -> DataType:
+     ...
+
+
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = 38, scale: int = 18
+ ) -> DataType:
+     """
+     Returns a data type with a fixed decimal type.
+
+     The precision and scale of the decimal type are fixed with the given values.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A", DecimalType(10, 0), True),
+     ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', DecimalType(38,18), True),
+         StructField('B', DecimalType(38,18), False)])
+
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(DecimalType(5, 0),
+     ...                     ArrayType(DecimalType(20, 0), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", DecimalType(30, 15), False)]),
+     ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(DecimalType(30,15),
+                 ArrayType(DecimalType(30,15), False), False), False),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', DecimalType(30,15), False)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                     nullable=field.nullable,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(
+             force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+             containsNull=dt.containsNull,
+         )
+     elif isinstance(dt, MapType):
+         return MapType(
+             force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+             force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+             valueContainsNull=dt.valueContainsNull,
+         )
+     elif isinstance(dt, DecimalType):
+         return DecimalType(precision=precision, scale=scale)
+     else:
+         return dt
+
+
+ def _test() -> None:
+     import doctest
+     import sys
+     import pyspark.pandas.spark.utils
+
+     globs = pyspark.pandas.spark.utils.__dict__.copy()
+     (failure_count, test_count) = doctest.testmod(
+         pyspark.pandas.spark.utils,
+         globs=globs,
+         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+     )
+     if failure_count:
+         sys.exit(-1)
+
+
+ if __name__ == "__main__":
+     _test()
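
The two schema utilities above are pure functions over `pyspark.sql.types` objects, so they can be exercised without a running session. A short sketch, assuming only PySpark is installed (the column names are illustrative):

```python
# Illustrative only: normalizing a schema with the two helpers added above.
from pyspark.sql.types import DecimalType, LongType, StructField, StructType
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale

schema = StructType([
    StructField("id", LongType(), nullable=False),
    StructField("amount", DecimalType(10, 2), nullable=False),
])

# Every field becomes nullable, and every DecimalType is widened to the
# default DecimalType(38, 18); other types pass through unchanged.
normalized = force_decimal_precision_scale(as_nullable_spark_type(schema))
print(normalized)
# StructType([StructField('id', LongType(), True),
#             StructField('amount', DecimalType(38,18), True)])
```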
@@ -19,6 +19,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  )
  from snowflake.snowpark.functions import lit
  from snowflake.snowpark.types import BooleanType, StringType
+ from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
  from snowflake.snowpark_connect.config import (
      auto_uppercase_non_column_identifiers,
      global_config,
@@ -743,7 +744,9 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          sp_schema = proto_to_snowpark_type(schema)
          columns = [c.name for c in schema.struct.fields]
          table_name_parts = split_fully_qualified_spark_name(tableName)
-         qualifiers = [table_name_parts for _ in columns]
+         qualifiers: list[set[ColumnQualifier]] = [
+             {ColumnQualifier(tuple(table_name_parts))} for _ in columns
+         ]
          column_types = [f.datatype for f in sp_schema.fields]
          return DataFrameContainer.create_with_column_mapping(
              dataframe=session.createDataFrame([], sp_schema),
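
The hunks in the remaining files share one theme: per-column qualifiers change from a flat `list[str]` of name parts to a `set[ColumnQualifier]`, built as shown above with `ColumnQualifier(tuple(...))`. A rough sketch of the shape change, based only on the constructor usage visible in this diff (the `ColumnQualifier` class itself, added in `column_qualifier.py`, is not shown here):

```python
# Hypothetical illustration; the values are made up, only the types mirror the diff.
from snowflake.snowpark_connect.column_qualifier import ColumnQualifier

table_name_parts = ["my_db", "my_schema", "my_table"]
columns = ["ID", "NAME", "AMOUNT"]

# Old shape: the same list of name parts repeated for every column.
old_qualifiers = [table_name_parts for _ in columns]  # list[list[str]]

# New shape: a set of hashable ColumnQualifier objects per column, so a column
# can carry more than one qualifier and duplicates collapse naturally.
new_qualifiers = [
    {ColumnQualifier(tuple(table_name_parts))} for _ in columns
]  # list[set[ColumnQualifier]]
```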
@@ -16,6 +16,7 @@ from snowflake.snowpark.types import DataType
  from snowflake.snowpark_connect.column_name_handler import (
      make_column_names_snowpark_compatible,
  )
+ from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
  from snowflake.snowpark_connect.expression.map_expression import (
@@ -200,7 +201,7 @@ def map_pivot_aggregate(
          dataframe=result.select(*column_selectors),
          spark_column_names=reordered_spark_names,
          snowpark_column_names=reordered_snowpark_names,
-         column_qualifiers=[[]] * len(reordered_spark_names),
+         column_qualifiers=[set() for _ in reordered_spark_names],
          parent_column_name_map=input_container.column_map,
          snowpark_column_types=reordered_types,
      )
@@ -349,7 +350,7 @@ class _ColumnMetadata:
      spark_name: str
      snowpark_name: str
      data_type: DataType
-     qualifiers: list[str]
+     qualifiers: set[ColumnQualifier]


  @dataclass(frozen=True)
@@ -385,7 +386,7 @@ class _Columns:
              col.spark_name for col in self.grouping_columns + self.aggregation_columns
          ]

-     def get_qualifiers(self) -> list[list[str]]:
+     def get_qualifiers(self) -> list[set[ColumnQualifier]]:
          return [
              col.qualifiers for col in self.grouping_columns + self.aggregation_columns
          ]
@@ -429,7 +430,7 @@ def map_aggregate_helper(
                  new_name,
                  None if skip_alias else alias,
                  None if pivot else snowpark_column.typ,
-                 snowpark_column.get_qualifiers(),
+                 qualifiers=snowpark_column.get_qualifiers(),
              )
          )

@@ -469,7 +470,7 @@ def map_aggregate_helper(
                  new_name,
                  None if skip_alias else alias,
                  agg_col_typ,
-                 [],
+                 qualifiers=set(),
              )
          )

@@ -29,6 +29,7 @@ from snowflake.snowpark.column import Column
  from snowflake.snowpark.table_function import _ExplodeFunctionCall
  from snowflake.snowpark.types import DataType, StructField, StructType, _NumericType
  from snowflake.snowpark_connect.column_name_handler import (
+     ColumnQualifier,
      make_column_names_snowpark_compatible,
  )
  from snowflake.snowpark_connect.config import global_config
@@ -315,6 +316,11 @@ def map_project(
          final_snowpark_columns = make_column_names_snowpark_compatible(
              new_spark_columns, rel.common.plan_id
          )
+         # if there are duplicate snowpark column names, we need to disambiguate them by their index
+         if len(new_spark_columns) != len(set(new_spark_columns)):
+             result = result.select(
+                 [f"${i}" for i in range(1, len(new_spark_columns) + 1)]
+             )
          result = result.toDF(*final_snowpark_columns)
          new_snowpark_columns = final_snowpark_columns

@@ -1014,7 +1020,7 @@ def map_unpivot(
      column_project = []
      column_reverse_project = []
      snowpark_columns = []
-     qualifiers = []
+     qualifiers: list[set[ColumnQualifier]] = []
      for c in input_container.column_map.get_snowpark_columns():
          c_name = snowpark_functions_col(c, input_container.column_map).get_name()
          if c_name in unpivot_col_names:
@@ -1042,7 +1048,7 @@
              )
              snowpark_columns.append(c)
              qualifiers.append(
-                 input_container.column_map.get_qualifier_for_spark_column(c)
+                 input_container.column_map.get_qualifiers_for_spark_column(c)
              )

      # Without the case when postprocessing, the result Spark dataframe is:
@@ -1087,7 +1093,7 @@
          snowpark_functions_col(snowpark_value_column_name, input_container.column_map)
      )
      snowpark_columns.append(snowpark_value_column_name)
-     qualifiers.extend([[]] * 2)
+     qualifiers.extend([set() for _ in range(2)])

      result = (
          input_df.select(*column_project)
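
One behavioral addition worth noting in the `map_project` hunk above: when the projected Spark names collide, the new branch first re-selects the columns positionally (Snowflake accepts `$1`, `$2`, ... as positional column references, and the diff relies on Snowpark passing them through) so that the subsequent `toDF(*final_snowpark_columns)` rename is unambiguous. A small runnable sketch of the selector construction, with made-up column names:

```python
# Illustrative only: the duplicate check and the positional selectors it builds.
new_spark_columns = ["a", "a", "b"]  # two output columns share a name

has_duplicates = len(new_spark_columns) != len(set(new_spark_columns))
positional_selectors = [f"${i}" for i in range(1, len(new_spark_columns) + 1)]

print(has_duplicates)         # True
print(positional_selectors)   # ['$1', '$2', '$3']
```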
@@ -15,6 +15,7 @@ from snowflake.snowpark_connect.column_name_handler import (
      ColumnNameMap,
      make_column_names_snowpark_compatible,
  )
+ from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
  from snowflake.snowpark_connect.config import get_boolean_session_config_param
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -178,7 +179,7 @@ def get_udtf_project(relation: relation_proto.Relation) -> bool:

  def handle_udtf_with_table_arguments(
      udtf_info: snowflake_proto.UDTFWithTableArguments,
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
      """
      Handle UDTF with one or more table arguments using Snowpark's join_table_function.
      For multiple table arguments, this creates a Cartesian product of all input tables.
@@ -286,7 +287,7 @@ def handle_lateral_join_with_udtf(
      left_result: DataFrameContainer,
      udtf_relation: relation_proto.Relation,
      udtf_info: tuple[snowpark.udtf.UserDefinedTableFunction, list],
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
      """
      Handle lateral join with UDTF on the right side using join_table_function.
      """
@@ -319,7 +320,7 @@

  def map_aggregate(
      aggregate: snowflake_proto.Aggregate, plan_id: int
- ) -> snowpark.DataFrame:
+ ) -> DataFrameContainer:
      input_container = map_relation(aggregate.input)
      input_df: snowpark.DataFrame = input_container.dataframe

@@ -363,7 +364,7 @@ def map_aggregate(
          return new_names[0], snowpark_column

      raw_groupings: list[tuple[str, TypedColumn]] = []
-     raw_aggregations: list[tuple[str, TypedColumn, list[str]]] = []
+     raw_aggregations: list[tuple[str, TypedColumn, set[ColumnQualifier]]] = []

      if not is_group_by_all:
          raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -401,11 +402,11 @@ def map_aggregate(
              col = _map_column(exp)
              if exp.WhichOneof("expr_type") == "unresolved_attribute":
                  spark_name = col[0]
-                 qualifiers = input_container.column_map.get_qualifier_for_spark_column(
-                     spark_name
-                 )
+                 qualifiers: set[
+                     ColumnQualifier
+                 ] = input_container.column_map.get_qualifiers_for_spark_column(spark_name)
              else:
-                 qualifiers = []
+                 qualifiers = set()

              raw_aggregations.append((col[0], col[1], qualifiers))

@@ -438,7 +439,7 @@ def map_aggregate(
      spark_columns: list[str] = []
      snowpark_columns: list[str] = []
      snowpark_column_types: list[snowpark_types.DataType] = []
-     all_qualifiers: list[list[str]] = []
+     all_qualifiers: list[set[ColumnQualifier]] = []

      # Use grouping columns directly without aliases
      groupings = [col.col for _, col in raw_groupings]