snowpark-connect 0.32.0__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (98)
  1. snowflake/snowpark_connect/column_name_handler.py +92 -27
  2. snowflake/snowpark_connect/column_qualifier.py +0 -4
  3. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  4. snowflake/snowpark_connect/expression/map_sql_expression.py +12 -4
  5. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +58 -21
  6. snowflake/snowpark_connect/expression/map_unresolved_function.py +62 -27
  7. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  8. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  9. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  10. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  11. snowflake/snowpark_connect/relation/map_aggregate.py +2 -4
  12. snowflake/snowpark_connect/relation/map_column_ops.py +5 -0
  13. snowflake/snowpark_connect/relation/map_join.py +218 -146
  14. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  15. snowflake/snowpark_connect/relation/map_sql.py +102 -16
  16. snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
  17. snowflake/snowpark_connect/relation/utils.py +46 -0
  18. snowflake/snowpark_connect/relation/write/map_write.py +186 -275
  19. snowflake/snowpark_connect/resources_initializer.py +25 -13
  20. snowflake/snowpark_connect/server.py +9 -24
  21. snowflake/snowpark_connect/type_mapping.py +2 -0
  22. snowflake/snowpark_connect/typed_column.py +2 -2
  23. snowflake/snowpark_connect/utils/sequence.py +21 -0
  24. snowflake/snowpark_connect/utils/session.py +8 -1
  25. snowflake/snowpark_connect/version.py +1 -1
  26. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +3 -1
  27. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +35 -93
  28. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  29. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  30. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  31. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  32. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  33. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  34. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  35. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  36. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  91. {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
  92. {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
  93. {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
  94. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
  95. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
  96. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
  97. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
  98. {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py

@@ -1,15 +1,21 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import dataclasses
+from enum import Enum
 from functools import reduce
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors import AnalysisException
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
+)
 from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
@@ -43,6 +49,25 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+
+    def has_join_condition(self) -> bool:
+        return self.condition_type == ConditionType.JOIN_CONDITION
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
+
+
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
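Note: the new ConditionType/JoinInfo pair collapses the three ways a Spark Connect join can specify its keys (an explicit join condition, USING/natural-join columns, or no condition at all) into one value that the rest of map_join branches on. A minimal sketch of the intended usage, assuming only the definitions added above:

    # Illustrative only: how a USING join on ["id"] would be described.
    info = JoinInfo(
        join_type="inner",
        condition_type=ConditionType.USING_COLUMNS,
        join_columns=["id"],
    )
    assert info.is_using_columns() and not info.has_join_condition()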
@@ -54,48 +79,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
 
     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
-    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
-
-    match rel.join.join_type:
-        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
-            # TODO: Understand what UNSPECIFIED Join type is
-            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-        case relation_proto.Join.JOIN_TYPE_INNER:
-            join_type = "inner"
-        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
-            join_type = "full_outer"
-        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
-            join_type = "left"
-        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
-            join_type = "right"
-        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
-            join_type = "leftanti"
-        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
-            join_type = "leftsemi"
-        case relation_proto.Join.JOIN_TYPE_CROSS:
-            join_type = "cross"
-        case other:
-            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
 
-    if rel.join.HasField("join_condition"):
-        assert not using_columns
+    join_info = _get_join_info(rel, left_container, right_container)
+    join_type = join_info.join_type
 
+    if join_info.has_join_condition():
         left_columns = list(left_container.column_map.spark_to_col.keys())
         right_columns = list(right_container.column_map.spark_to_col.keys())
 
@@ -122,72 +110,42 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         result: snowpark.DataFrame = left_input.join(
             right=right_input,
             on=join_expression.col,
-            how=join_type,
+            how="inner" if join_info.join_type == "cross" else join_info.join_type,
             lsuffix="_left",
             rsuffix="_right",
         )
-    elif using_columns:
-        if any(
-            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
+    elif join_info.is_using_columns():
+        # TODO: disambiguate snowpark columns for all join condition types
+        # disambiguation temporarily done only for using_columns/natural joins to reduce changes
+        left_container, right_container = _disambiguate_snowpark_columns(
+            left_container, right_container
+        )
+        left_input = left_container.dataframe
+        right_input = right_container.dataframe
+
+        join_columns = join_info.join_columns
+
+        def _validate_using_column(
+            column: str, container: DataFrameContainer, side: str
+        ) -> None:
+            if (
+                container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    column, allow_non_exists=True, return_first=True
                 )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
-        if any(
-            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
+                is None
+            ):
+                exception = AnalysisException(
+                    USING_COLUMN_NOT_FOUND_ERROR.format(
+                        column, side, container.column_map.get_spark_columns()
+                    )
                 )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
+                attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+                raise exception
+
+        for col in join_columns:
+            _validate_using_column(col, left_container, "left")
+            _validate_using_column(col, right_container, "right")
 
-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
         # We cannot assume that Snowpark will have the same names for left and right columns,
         # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
         # then drop right["a"] and right["b"].
@@ -195,16 +153,16 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             (
                 left_input[
                     left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        lft, return_first=True
+                        spark_name, return_first=True
                     )
                 ],
                 right_input[
                     right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        r, return_first=True
+                        spark_name, return_first=True
                    )
                 ],
             )
-            for lft, r in using_columns
+            for spark_name in join_columns
         ]
         joined_df = left_input.join(
             right=right_input,
@@ -240,10 +198,19 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             exception = SparkException.implicit_cartesian_product("inner")
             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
             raise exception
-        result: snowpark.DataFrame = left_input.join(
-            right=right_input,
-            how=join_type,
-        )
+        # For outer joins without a condition, we need to use a TRUE condition
+        # to match Spark's behavior.
+        if join_type in ["left", "right", "full_outer"]:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                on=snowpark_fn.lit(True),
+                how=join_type,
+            )
+        else:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                how=join_type,
+            )
 
     if join_type in ["leftanti", "leftsemi"]:
         # Join types that only return columns from the left side:
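Note: the new branch above exists because Spark treats an outer join with no condition as a join whose condition is trivially true, so every left row pairs with every right row, while a condition-less Snowpark join does not carry that intent for outer join types; passing lit(True) makes it explicit. An illustrative PySpark snippet of the behaviour being matched, assuming an existing `spark` session:

    left = spark.createDataFrame([(1,), (2,)], ["a"])
    right = spark.createDataFrame([(10,), (20,)], ["b"])
    # No join condition with how="left": each left row pairs with every right row,
    # which is what joining on a literal TRUE reproduces on the Snowpark side.
    left.join(right, how="left").count()  # 4 rows here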
@@ -253,39 +220,26 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
         qualifiers = left_container.column_map.get_qualifiers()
     else:
-        # Add Spark columns and plan_ids from left DF
-        spark_cols_after_join: list[str] = list(
-            left_container.column_map.get_spark_columns()
-        ) + [
-            spark_col
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+        if not join_info.is_using_columns():
+            spark_cols_after_join: list[str] = (
+                left_container.column_map.get_spark_columns()
+                + right_container.column_map.get_spark_columns()
             )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once
-        ]
-
-        qualifiers: list[set[ColumnQualifier]] = list(
-            left_container.column_map.get_qualifiers()
-        ) + [
-            {right_container.column_map.get_qualifier_for_spark_column(spark_col)}
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
+            qualifiers: list[set[ColumnQualifier]] = (
+                left_container.column_map.get_qualifiers()
+                + right_container.column_map.get_qualifiers()
             )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once]
-        ]
-
-        column_metadata = {}
-        if left_container.column_map.column_metadata:
-            column_metadata.update(left_container.column_map.column_metadata)
+        else:
+            # get columns after join
+            joined_columns = left_container.column_map.get_columns_after_join(
+                right_container.column_map, join_info.join_columns
+            )
+            spark_cols_after_join: list[str] = [c.spark_name for c in joined_columns]
+            qualifiers: list[set[ColumnQualifier]] = [
+                c.qualifiers for c in joined_columns
+            ]
 
+    column_metadata = dict(left_container.column_map.column_metadata or {})
     if right_container.column_map.column_metadata:
         for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
@@ -318,7 +272,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     # After a USING join, references to the right dataframe's columns should resolve
     # to the result dataframe that contains the merged columns
     if (
-        using_columns
+        join_info.is_using_columns()
         and rel.join.right.HasField("common")
        and rel.join.right.common.HasField("plan_id")
     ):
@@ -328,7 +282,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     # For FULL OUTER joins, we also need to map the left dataframe's plan_id
     # since both columns are replaced with a coalesced column
     if (
-        using_columns
+        join_info.is_using_columns()
         and join_type == "full_outer"
         and rel.join.left.HasField("common")
         and rel.join.left.common.HasField("plan_id")
@@ -336,12 +290,12 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         left_plan_id = rel.join.left.common.plan_id
         set_plan_id_map(left_plan_id, result_container)
 
-    if rel.join.using_columns:
+    if join_info.is_using_columns():
         # When join 'using_columns', the 'join columns' should go first in result DF.
-        idxs_to_shift = [
-            spark_cols_after_join.index(left_col_name)
-            for left_col_name in case_corrected_left_columns
-        ]
+        # we're only shifting left side columns, since we dropped the right-side ones
+        idxs_to_shift = left_container.column_map.get_column_indexes(
+            join_info.join_columns
+        )
 
         def reorder(lst: list) -> list:
             to_move = [lst[i] for i in idxs_to_shift]
@@ -370,3 +324,121 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         )
 
     return result_container
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns
+
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    return JoinInfo(join_type, condition_type, join_columns)
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer
+) -> tuple[DataFrameContainer, DataFrameContainer]:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return left, right
+
+    # rename and create new containers
+    return _disambiguate_container(
+        left, conflicting_snowpark_columns
+    ), _disambiguate_container(right, conflicting_snowpark_columns)
+
+
+def _disambiguate_container(
+    container: DataFrameContainer, conflicting_snowpark_columns: set[str]
+) -> DataFrameContainer:
+    column_map = container.column_map
+    disambiguated_columns = []
+    disambiguated_snowpark_names = []
+    for c in column_map.columns:
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+    disambiguated_df = container.dataframe.select(*disambiguated_columns)
+
+    def _get_new_schema():
+        old_schema = container.dataframe.schema
+        if not old_schema.fields:
+            return StructType([])
+
+        new_fields = []
+        for i, name in enumerate(disambiguated_snowpark_names):
+            f = old_schema.fields[i]
+            new_fields.append(
+                StructField(name, f.datatype, nullable=f.nullable, _is_column=True)
+            )
+        return StructType(new_fields)
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=container.table_name,
+        cached_schema_getter=_get_new_schema,
+    )
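Note: _disambiguate_container is the core of the new USING/natural-join handling. When both inputs expose the same underlying Snowpark column name, each conflicting column is re-aliased to a fresh unique name (via make_unique_snowpark_name) before the join, and a new container is built with a precomputed schema getter so the rename does not trigger an extra describe query. A rough sketch of the aliasing idea in plain Snowpark, with hypothetical dataframe and column names:

    from snowflake.snowpark.functions import col

    # Both sides expose a column literally named "ID"; rename one side first,
    # then join on the disambiguated pair (the suffix is made up for illustration).
    right_renamed = right_df.select(col("ID").alias("ID_7F3A"))
    joined = left_df.join(right_renamed, left_df["ID"] == right_renamed["ID_7F3A"])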
snowflake/snowpark_connect/relation/map_row_ops.py

@@ -45,6 +45,61 @@ from snowflake.snowpark_connect.utils.telemetry import (
 )
 
 
+def cast_columns(
+    df_container: DataFrameContainer,
+    df_dtypes: list[snowpark.types.DataType],
+    target_dtypes: list[snowpark.types.DataType],
+    column_map: ColumnNameMap,
+):
+    df: snowpark.DataFrame = df_container.dataframe
+    if df_dtypes == target_dtypes:
+        return df_container
+    # Use cached schema if available to avoid triggering extra queries
+    if (
+        hasattr(df_container, "cached_schema_getter")
+        and df_container.cached_schema_getter is not None
+    ):
+        df_schema = df_container.cached_schema_getter()
+    else:
+        df_schema = df.schema  # Get current schema
+    new_columns = []
+
+    for i, field in enumerate(df_schema.fields):
+        col_name = field.name
+        current_type = field.datatype
+        target_type = target_dtypes[i]
+
+        if current_type != target_type:
+            new_columns.append(df[col_name].cast(target_type).alias(col_name))
+        else:
+            new_columns.append(df[col_name])
+
+    new_df = df.select(new_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=new_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=column_map.get_snowpark_columns(),
+        snowpark_column_types=target_dtypes,
+        column_metadata=column_map.column_metadata,
+        parent_column_name_map=column_map,
+    )
+
+
+def get_schema_from_result(
+    result: DataFrameContainer,
+) -> StructType:
+    """
+    Get schema from a DataFrameContainer, using cached schema if available to avoid extra queries.
+    """
+    if (
+        hasattr(result, "cached_schema_getter")
+        and result.cached_schema_getter is not None
+    ):
+        return result.cached_schema_getter()
+    else:
+        return result.dataframe.schema
+
+
 def map_deduplicate(
     rel: relation_proto.Relation,
 ) -> DataFrameContainer:
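Note: cast_columns previously lived inside map_union (its removal appears further down), and get_schema_from_result factors out the cached-schema fallback that map_union previously inlined; hoisting both to module level lets map_except reuse them. cast_columns is a no-op when the observed and target types already match, and both helpers prefer the container's cached_schema_getter because reading .schema from a Snowpark DataFrame can issue an extra describe query. The shared call pattern, roughly:

    # Sketch of the call pattern used by map_union/map_except below.
    left_schema = get_schema_from_result(left_result)      # cached schema if available
    left_dtypes = [f.datatype for f in left_schema.fields]
    # ... derive target_left_dtypes by widening/coercion ...
    left_result = cast_columns(
        left_result, left_dtypes, target_left_dtypes, left_result.column_map
    )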
@@ -205,21 +260,8 @@ def map_union(
 
     # workaround for unstructured type vs structured type
     # Use cached schema if available to avoid triggering extra queries
-    if (
-        hasattr(left_result, "cached_schema_getter")
-        and left_result.cached_schema_getter is not None
-    ):
-        left_schema = left_result.cached_schema_getter()
-    else:
-        left_schema = left_df.schema
-
-    if (
-        hasattr(right_result, "cached_schema_getter")
-        and right_result.cached_schema_getter is not None
-    ):
-        right_schema = right_result.cached_schema_getter()
-    else:
-        right_schema = right_df.schema
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
 
     left_dtypes = [field.datatype for field in left_schema.fields]
     right_dtypes = [field.datatype for field in right_schema.fields]
@@ -257,6 +299,29 @@ def map_union(
                 # Union of any type with null type is of the other type
                 target_left_dtypes.append(other_t)
                 target_right_dtypes.append(other_t)
+            case (snowpark.types.DecimalType(), snowpark.types.DecimalType()):
+                # Widen decimal types to accommodate both sides
+                # Calculate the maximum scale and maximum integer digits
+                left_integer_digits = left_type.precision - left_type.scale
+                right_integer_digits = right_type.precision - right_type.scale
+
+                # The common type needs to accommodate:
+                # - The maximum number of digits after the decimal point (scale)
+                # - The maximum number of digits before the decimal point (integer digits)
+                common_scale = max(left_type.scale, right_type.scale)
+                common_integer_digits = max(
+                    left_integer_digits, right_integer_digits
+                )
+                common_precision = min(38, common_scale + common_integer_digits)
+
+                # Ensure scale doesn't exceed precision
+                common_scale = min(common_scale, common_precision)
+
+                common_type = snowpark.types.DecimalType(
+                    common_precision, common_scale
+                )
+                target_left_dtypes.append(common_type)
+                target_right_dtypes.append(common_type)
             case (snowpark.types.BooleanType(), _) | (
                 _,
                 snowpark.types.BooleanType(),
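Note: a worked example of the decimal-widening rule added above. Unioning a DECIMAL(10, 2) column with a DECIMAL(8, 5) column must keep 8 integer digits and 5 fractional digits, so both sides are cast to DECIMAL(13, 5); precision is capped at Snowflake's maximum of 38.

    left_precision, left_scale = 10, 2
    right_precision, right_scale = 8, 5
    common_scale = max(left_scale, right_scale)                        # 5
    common_integer_digits = max(left_precision - left_scale,
                                right_precision - right_scale)         # 8
    common_precision = min(38, common_scale + common_integer_digits)   # 13
    # -> both sides are cast to DecimalType(13, 5)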
@@ -272,49 +337,24 @@ def map_union(
                     raise exception
                 target_left_dtypes.append(left_type)
                 target_right_dtypes.append(right_type)
+            case (
+                snowpark.types.TimestampType()
+                | snowpark.types.DateType()
+                | snowpark.types._NumericType(),
+                snowpark.types.StringType(),
+            ) | (
+                snowpark.types.StringType(),
+                snowpark.types.TimestampType()
+                | snowpark.types.DateType()
+                | snowpark.types._NumericType(),
+            ) if not spark_sql_ansi_enabled:
+                common_type = snowpark.types.StringType()
+                target_left_dtypes.append(common_type)
+                target_right_dtypes.append(common_type)
             case _:
                 target_left_dtypes.append(left_type)
                 target_right_dtypes.append(right_type)
 
-    def cast_columns(
-        df_container: DataFrameContainer,
-        df_dtypes: list[snowpark.types.DataType],
-        target_dtypes: list[snowpark.types.DataType],
-        column_map: ColumnNameMap,
-    ):
-        df: snowpark.DataFrame = df_container.dataframe
-        if df_dtypes == target_dtypes:
-            return df_container
-        # Use cached schema if available to avoid triggering extra queries
-        if (
-            hasattr(df_container, "cached_schema_getter")
-            and df_container.cached_schema_getter is not None
-        ):
-            df_schema = df_container.cached_schema_getter()
-        else:
-            df_schema = df.schema  # Get current schema
-        new_columns = []
-
-        for i, field in enumerate(df_schema.fields):
-            col_name = field.name
-            current_type = field.datatype
-            target_type = target_dtypes[i]
-
-            if current_type != target_type:
-                new_columns.append(df[col_name].cast(target_type).alias(col_name))
-            else:
-                new_columns.append(df[col_name])
-
-        new_df = df.select(new_columns)
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=new_df,
-            spark_column_names=column_map.get_spark_columns(),
-            snowpark_column_names=column_map.get_snowpark_columns(),
-            snowpark_column_types=target_dtypes,
-            column_metadata=column_map.column_metadata,
-            parent_column_name_map=column_map,
-        )
-
     left_result = cast_columns(
         left_result,
         left_dtypes,
@@ -527,6 +567,48 @@ def map_except(
     left_df = left_result.dataframe
     right_df = right_result.dataframe
 
+    # workaround for unstructured type vs structured type
+    # Use cached schema if available to avoid triggering extra queries
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
+
+    left_dtypes = [field.datatype for field in left_schema.fields]
+    right_dtypes = [field.datatype for field in right_schema.fields]
+
+    if left_dtypes != right_dtypes and not rel.set_op.by_name:
+        if len(left_dtypes) != len(right_dtypes):
+            exception = AnalysisException("UNION: the number of columns must match")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+            raise exception
+        target_left_dtypes, target_right_dtypes = [], []
+        for left_type, right_type in zip(left_dtypes, right_dtypes):
+            match (left_type, right_type):
+                case (snowpark.types._NumericType(), snowpark.types.StringType()) | (
+                    snowpark.types.StringType(),
+                    snowpark.types._NumericType(),
+                ):
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
+                case _:
+                    target_left_dtypes.append(left_type)
+                    target_right_dtypes.append(right_type)
+
+        left_result = cast_columns(
+            left_result,
+            left_dtypes,
+            target_left_dtypes,
+            left_result.column_map,
+        )
+        right_result = cast_columns(
+            right_result,
+            right_dtypes,
+            target_right_dtypes,
+            right_result.column_map,
+        )
+        left_df = left_result.dataframe
+        right_df = right_result.dataframe
+
     if rel.set_op.is_all:
         # Snowflake except removes all duplicated rows. In order to handle the case,
         # we add a partition row number column to the df to make duplicated rows unique to
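Note: map_except now applies the same pre-cast as map_union, though only for the numeric-vs-string case and without the ANSI-mode guard: when one side of an EXCEPT is numeric and the other is a string, both are cast to string before the set operation runs. An illustrative client-side snippet that would exercise this path, assuming an existing `spark` session served through snowpark-connect:

    df_num = spark.createDataFrame([(1,), (2,)], ["v"])      # numeric column
    df_str = spark.createDataFrame([("1",), ("3",)], ["v"])  # string column
    # Both sides are coerced to string on the server before the EXCEPT is executed.
    df_num.exceptAll(df_str).show()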