snowpark-connect 0.31.0__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (111)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +143 -105
  3. snowflake/snowpark_connect/column_qualifier.py +43 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  8. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  9. snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  16. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  17. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  18. snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
  19. snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
  20. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  21. snowflake/snowpark_connect/relation/map_join.py +219 -144
  22. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  23. snowflake/snowpark_connect/relation/map_sql.py +134 -16
  24. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  25. snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
  26. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  27. snowflake/snowpark_connect/relation/utils.py +46 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +215 -289
  29. snowflake/snowpark_connect/resources_initializer.py +25 -13
  30. snowflake/snowpark_connect/server.py +10 -26
  31. snowflake/snowpark_connect/type_mapping.py +38 -3
  32. snowflake/snowpark_connect/typed_column.py +8 -6
  33. snowflake/snowpark_connect/utils/sequence.py +21 -0
  34. snowflake/snowpark_connect/utils/session.py +27 -4
  35. snowflake/snowpark_connect/version.py +1 -1
  36. snowflake/snowpark_decoder/dp_session.py +1 -1
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
  39. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  99. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  100. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  101. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  102. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  103. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  104. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
  105. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
  106. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
  107. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
  108. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
  109. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
  110. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
  111. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py

@@ -1,15 +1,22 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import dataclasses
+from enum import Enum
 from functools import reduce
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors import AnalysisException
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
+)
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
@@ -42,6 +49,25 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+
+    def has_join_condition(self) -> bool:
+        return self.condition_type == ConditionType.JOIN_CONDITION
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
+
+
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
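
The ConditionType/JoinInfo pair added above replaces the ad-hoc local variables that map_join previously threaded through the function. A rough illustration of the values it carries, assuming the definitions just shown (the concrete values are hypothetical, not taken from the package's tests):

    # df1.join(df2, on=["id"]) arrives with using_columns=["id"] and is summarized as:
    info = JoinInfo(
        join_type="inner",
        condition_type=ConditionType.USING_COLUMNS,
        join_columns=["id"],
    )
    assert info.is_using_columns() and not info.has_join_condition()
    # An explicit condition such as df1.join(df2, df1.id == df2.id) instead yields
    # condition_type=ConditionType.JOIN_CONDITION, with join_columns left empty.
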
@@ -53,48 +79,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
 
     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
-    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
-
-    match rel.join.join_type:
-        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
-            # TODO: Understand what UNSPECIFIED Join type is
-            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-        case relation_proto.Join.JOIN_TYPE_INNER:
-            join_type = "inner"
-        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
-            join_type = "full_outer"
-        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
-            join_type = "left"
-        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
-            join_type = "right"
-        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
-            join_type = "leftanti"
-        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
-            join_type = "leftsemi"
-        case relation_proto.Join.JOIN_TYPE_CROSS:
-            join_type = "cross"
-        case other:
-            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
 
-    if rel.join.HasField("join_condition"):
-        assert not using_columns
+    join_info = _get_join_info(rel, left_container, right_container)
+    join_type = join_info.join_type
 
+    if join_info.has_join_condition():
         left_columns = list(left_container.column_map.spark_to_col.keys())
         right_columns = list(right_container.column_map.spark_to_col.keys())
 
@@ -121,72 +110,42 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
         result: snowpark.DataFrame = left_input.join(
             right=right_input,
             on=join_expression.col,
-            how=join_type,
+            how="inner" if join_info.join_type == "cross" else join_info.join_type,
             lsuffix="_left",
             rsuffix="_right",
         )
-    elif using_columns:
-        if any(
-            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
-                )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
-        if any(
-            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
-                )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
-
-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
+    elif join_info.is_using_columns():
+        # TODO: disambiguate snowpark columns for all join condition types
+        # disambiguation temporarily done only for using_columns/natural joins to reduce changes
+        left_container, right_container = _disambiguate_snowpark_columns(
+            left_container, right_container
+        )
+        left_input = left_container.dataframe
+        right_input = right_container.dataframe
+
+        join_columns = join_info.join_columns
+
+        def _validate_using_column(
+            column: str, container: DataFrameContainer, side: str
+        ) -> None:
+            if (
+                container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    column, allow_non_exists=True, return_first=True
+                )
+                is None
+            ):
+                exception = AnalysisException(
+                    USING_COLUMN_NOT_FOUND_ERROR.format(
+                        column, side, container.column_map.get_spark_columns()
+                    )
+                )
+                attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+                raise exception
+
+        for col in join_columns:
+            _validate_using_column(col, left_container, "left")
+            _validate_using_column(col, right_container, "right")
 
         # We cannot assume that Snowpark will have the same names for left and right columns,
         # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
         # then drop right["a"] and right["b"].
@@ -194,16 +153,16 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             (
                 left_input[
                     left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        lft, return_first=True
+                        spark_name, return_first=True
                    )
                ],
                right_input[
                    right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        r, return_first=True
+                        spark_name, return_first=True
                    )
                ],
            )
-            for lft, r in using_columns
+            for spark_name in join_columns
        ]
        joined_df = left_input.join(
            right=right_input,
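
For join_columns == ["a", "b"], the comprehension above builds one (left, right) column pair per USING column; conceptually the Snowpark call that follows is roughly equivalent to the sketch below (left_df/right_df and the column names are hypothetical stand-ins):

    cond = (left_df["a"] == right_df["a"]) & (left_df["b"] == right_df["b"])
    joined = left_df.join(right_df, on=cond, how="inner")
    # ... after which the right-hand copies of "a" and "b" are dropped,
    # as the comment above the comprehension describes.
    joined = joined.drop(right_df["a"], right_df["b"])
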
@@ -239,10 +198,19 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
            exception = SparkException.implicit_cartesian_product("inner")
            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
            raise exception
-        result: snowpark.DataFrame = left_input.join(
-            right=right_input,
-            how=join_type,
-        )
+        # For outer joins without a condition, we need to use a TRUE condition
+        # to match Spark's behavior.
+        if join_type in ["left", "right", "full_outer"]:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                on=snowpark_fn.lit(True),
+                how=join_type,
+            )
+        else:
+            result: snowpark.DataFrame = left_input.join(
+                right=right_input,
+                how=join_type,
+            )
 
    if join_type in ["leftanti", "leftsemi"]:
        # Join types that only return columns from the left side:
@@ -252,37 +220,26 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
        spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
        qualifiers = left_container.column_map.get_qualifiers()
    else:
-        # Add Spark columns and plan_ids from left DF
-        spark_cols_after_join: list[str] = list(
-            left_container.column_map.get_spark_columns()
-        ) + [
-            spark_col
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
-            )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once
-        ]
-
-        qualifiers = list(left_container.column_map.get_qualifiers()) + [
-            right_container.column_map.get_qualifier_for_spark_column(spark_col)
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
-            )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once]
-        ]
-
-        column_metadata = {}
-        if left_container.column_map.column_metadata:
-            column_metadata.update(left_container.column_map.column_metadata)
+        if not join_info.is_using_columns():
+            spark_cols_after_join: list[str] = (
+                left_container.column_map.get_spark_columns()
+                + right_container.column_map.get_spark_columns()
+            )
+            qualifiers: list[set[ColumnQualifier]] = (
+                left_container.column_map.get_qualifiers()
+                + right_container.column_map.get_qualifiers()
+            )
+        else:
+            # get columns after join
+            joined_columns = left_container.column_map.get_columns_after_join(
+                right_container.column_map, join_info.join_columns
+            )
+            spark_cols_after_join: list[str] = [c.spark_name for c in joined_columns]
+            qualifiers: list[set[ColumnQualifier]] = [
+                c.qualifiers for c in joined_columns
+            ]
 
+    column_metadata = dict(left_container.column_map.column_metadata or {})
    if right_container.column_map.column_metadata:
        for key, value in right_container.column_map.column_metadata.items():
            if key not in column_metadata:
@@ -315,7 +272,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    # After a USING join, references to the right dataframe's columns should resolve
    # to the result dataframe that contains the merged columns
    if (
-        using_columns
+        join_info.is_using_columns()
        and rel.join.right.HasField("common")
        and rel.join.right.common.HasField("plan_id")
    ):
@@ -325,7 +282,7 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    # For FULL OUTER joins, we also need to map the left dataframe's plan_id
    # since both columns are replaced with a coalesced column
    if (
-        using_columns
+        join_info.is_using_columns()
        and join_type == "full_outer"
        and rel.join.left.HasField("common")
        and rel.join.left.common.HasField("plan_id")
@@ -333,12 +290,12 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
        left_plan_id = rel.join.left.common.plan_id
        set_plan_id_map(left_plan_id, result_container)
 
-    if rel.join.using_columns:
+    if join_info.is_using_columns():
        # When join 'using_columns', the 'join columns' should go first in result DF.
-        idxs_to_shift = [
-            spark_cols_after_join.index(left_col_name)
-            for left_col_name in case_corrected_left_columns
-        ]
+        # we're only shifting left side columns, since we dropped the right-side ones
+        idxs_to_shift = left_container.column_map.get_column_indexes(
+            join_info.join_columns
+        )
 
        def reorder(lst: list) -> list:
            to_move = [lst[i] for i in idxs_to_shift]
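
The reorder helper exists because Spark emits USING/NATURAL join keys only once and puts them first in the result. A small PySpark-level illustration of the behavior being matched (assumes a running Spark session named spark):

    left = spark.createDataFrame([(1, "x")], ["id", "l_val"])
    right = spark.createDataFrame([(1, "y")], ["id", "r_val"])
    left.join(right, on=["id"], how="inner").columns
    # ['id', 'l_val', 'r_val'] -- "id" appears once and comes first, which is
    # what idxs_to_shift / reorder reproduce on the Snowpark side.
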
@@ -367,3 +324,121 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
    )
 
    return result_container
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns
+
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    return JoinInfo(join_type, condition_type, join_columns)
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer
+) -> tuple[DataFrameContainer, DataFrameContainer]:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return left, right
+
+    # rename and create new containers
+    return _disambiguate_container(
+        left, conflicting_snowpark_columns
+    ), _disambiguate_container(right, conflicting_snowpark_columns)
+
+
+def _disambiguate_container(
+    container: DataFrameContainer, conflicting_snowpark_columns: set[str]
+) -> DataFrameContainer:
+    column_map = container.column_map
+    disambiguated_columns = []
+    disambiguated_snowpark_names = []
+    for c in column_map.columns:
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+    disambiguated_df = container.dataframe.select(*disambiguated_columns)
+
+    def _get_new_schema():
+        old_schema = container.dataframe.schema
+        if not old_schema.fields:
+            return StructType([])
+
+        new_fields = []
+        for i, name in enumerate(disambiguated_snowpark_names):
+            f = old_schema.fields[i]
+            new_fields.append(
+                StructField(name, f.datatype, nullable=f.nullable, _is_column=True)
+            )
+        return StructType(new_fields)
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=container.table_name,
+        cached_schema_getter=_get_new_schema,
+    )
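
_disambiguate_container only re-aliases the Snowpark names that collide across the two join inputs; the Spark-visible names, metadata, and qualifiers are carried over into the rebuilt container, and _get_new_schema keeps the cached schema in sync without an extra describe query. A schematic sketch of the effect (the conflict set and the generated alias are hypothetical; make_unique_snowpark_name decides the real name):

    # Both inputs expose the Snowpark column "ID":
    conflicting = {"ID"}
    new_left = _disambiguate_container(left_container, conflicting)
    # new_left.dataframe now selects snowpark_fn.col("ID").alias(<unique name>),
    # while non-conflicting columns pass through unchanged, so the left and right
    # "ID" columns no longer clash when the USING-join equalities are built.
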
snowflake/snowpark_connect/relation/map_row_ops.py

@@ -45,6 +45,61 @@ from snowflake.snowpark_connect.utils.telemetry import (
 )
 
 
+def cast_columns(
+    df_container: DataFrameContainer,
+    df_dtypes: list[snowpark.types.DataType],
+    target_dtypes: list[snowpark.types.DataType],
+    column_map: ColumnNameMap,
+):
+    df: snowpark.DataFrame = df_container.dataframe
+    if df_dtypes == target_dtypes:
+        return df_container
+    # Use cached schema if available to avoid triggering extra queries
+    if (
+        hasattr(df_container, "cached_schema_getter")
+        and df_container.cached_schema_getter is not None
+    ):
+        df_schema = df_container.cached_schema_getter()
+    else:
+        df_schema = df.schema  # Get current schema
+    new_columns = []
+
+    for i, field in enumerate(df_schema.fields):
+        col_name = field.name
+        current_type = field.datatype
+        target_type = target_dtypes[i]
+
+        if current_type != target_type:
+            new_columns.append(df[col_name].cast(target_type).alias(col_name))
+        else:
+            new_columns.append(df[col_name])
+
+    new_df = df.select(new_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=new_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=column_map.get_snowpark_columns(),
+        snowpark_column_types=target_dtypes,
+        column_metadata=column_map.column_metadata,
+        parent_column_name_map=column_map,
+    )
+
+
+def get_schema_from_result(
+    result: DataFrameContainer,
+) -> StructType:
+    """
+    Get schema from a DataFrameContainer, using cached schema if available to avoid extra queries.
+    """
+    if (
+        hasattr(result, "cached_schema_getter")
+        and result.cached_schema_getter is not None
+    ):
+        return result.cached_schema_getter()
+    else:
+        return result.dataframe.schema
+
+
 def map_deduplicate(
    rel: relation_proto.Relation,
 ) -> DataFrameContainer:
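
Both helpers are hoisted to module level (cast_columns used to be nested inside map_union) so that map_except can reuse them below. The intended call pattern is simply the following sketch, assuming a populated DataFrameContainer named result and a target_dtypes list produced by the type-widening logic:

    schema = get_schema_from_result(result)   # prefers result.cached_schema_getter()
    dtypes = [field.datatype for field in schema.fields]
    result = cast_columns(result, dtypes, target_dtypes, result.column_map)
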
@@ -205,21 +260,8 @@ def map_union(
 
    # workaround for unstructured type vs structured type
    # Use cached schema if available to avoid triggering extra queries
-    if (
-        hasattr(left_result, "cached_schema_getter")
-        and left_result.cached_schema_getter is not None
-    ):
-        left_schema = left_result.cached_schema_getter()
-    else:
-        left_schema = left_df.schema
-
-    if (
-        hasattr(right_result, "cached_schema_getter")
-        and right_result.cached_schema_getter is not None
-    ):
-        right_schema = right_result.cached_schema_getter()
-    else:
-        right_schema = right_df.schema
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
 
    left_dtypes = [field.datatype for field in left_schema.fields]
    right_dtypes = [field.datatype for field in right_schema.fields]
@@ -257,6 +299,29 @@ def map_union(
                    # Union of any type with null type is of the other type
                    target_left_dtypes.append(other_t)
                    target_right_dtypes.append(other_t)
+                case (snowpark.types.DecimalType(), snowpark.types.DecimalType()):
+                    # Widen decimal types to accommodate both sides
+                    # Calculate the maximum scale and maximum integer digits
+                    left_integer_digits = left_type.precision - left_type.scale
+                    right_integer_digits = right_type.precision - right_type.scale
+
+                    # The common type needs to accommodate:
+                    # - The maximum number of digits after the decimal point (scale)
+                    # - The maximum number of digits before the decimal point (integer digits)
+                    common_scale = max(left_type.scale, right_type.scale)
+                    common_integer_digits = max(
+                        left_integer_digits, right_integer_digits
+                    )
+                    common_precision = min(38, common_scale + common_integer_digits)
+
+                    # Ensure scale doesn't exceed precision
+                    common_scale = min(common_scale, common_precision)
+
+                    common_type = snowpark.types.DecimalType(
+                        common_precision, common_scale
+                    )
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
                case (snowpark.types.BooleanType(), _) | (
                    _,
                    snowpark.types.BooleanType(),
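
Restating the arithmetic of the new DecimalType branch with concrete numbers (a hand-worked example, not taken from the package's tests):

    left_type = snowpark.types.DecimalType(10, 2)   # 8 integer digits
    right_type = snowpark.types.DecimalType(8, 5)   # 3 integer digits
    common_scale = max(2, 5)                        # 5
    common_integer_digits = max(10 - 2, 8 - 5)      # 8
    common_precision = min(38, 5 + 8)               # 13
    # Both sides are therefore cast to DecimalType(13, 5), which can hold
    # every value either input type can represent.
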
@@ -272,49 +337,24 @@ def map_union(
                    raise exception
                    target_left_dtypes.append(left_type)
                    target_right_dtypes.append(right_type)
+                case (
+                    snowpark.types.TimestampType()
+                    | snowpark.types.DateType()
+                    | snowpark.types._NumericType(),
+                    snowpark.types.StringType(),
+                ) | (
+                    snowpark.types.StringType(),
+                    snowpark.types.TimestampType()
+                    | snowpark.types.DateType()
+                    | snowpark.types._NumericType(),
+                ) if not spark_sql_ansi_enabled:
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
                case _:
                    target_left_dtypes.append(left_type)
                    target_right_dtypes.append(right_type)
 
-    def cast_columns(
-        df_container: DataFrameContainer,
-        df_dtypes: list[snowpark.types.DataType],
-        target_dtypes: list[snowpark.types.DataType],
-        column_map: ColumnNameMap,
-    ):
-        df: snowpark.DataFrame = df_container.dataframe
-        if df_dtypes == target_dtypes:
-            return df_container
-        # Use cached schema if available to avoid triggering extra queries
-        if (
-            hasattr(df_container, "cached_schema_getter")
-            and df_container.cached_schema_getter is not None
-        ):
-            df_schema = df_container.cached_schema_getter()
-        else:
-            df_schema = df.schema  # Get current schema
-        new_columns = []
-
-        for i, field in enumerate(df_schema.fields):
-            col_name = field.name
-            current_type = field.datatype
-            target_type = target_dtypes[i]
-
-            if current_type != target_type:
-                new_columns.append(df[col_name].cast(target_type).alias(col_name))
-            else:
-                new_columns.append(df[col_name])
-
-        new_df = df.select(new_columns)
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=new_df,
-            spark_column_names=column_map.get_spark_columns(),
-            snowpark_column_names=column_map.get_snowpark_columns(),
-            snowpark_column_types=target_dtypes,
-            column_metadata=column_map.column_metadata,
-            parent_column_name_map=column_map,
-        )
-
    left_result = cast_columns(
        left_result,
        left_dtypes,
@@ -527,6 +567,48 @@ def map_except(
    left_df = left_result.dataframe
    right_df = right_result.dataframe
 
+    # workaround for unstructured type vs structured type
+    # Use cached schema if available to avoid triggering extra queries
+    left_schema = get_schema_from_result(left_result)
+    right_schema = get_schema_from_result(right_result)
+
+    left_dtypes = [field.datatype for field in left_schema.fields]
+    right_dtypes = [field.datatype for field in right_schema.fields]
+
+    if left_dtypes != right_dtypes and not rel.set_op.by_name:
+        if len(left_dtypes) != len(right_dtypes):
+            exception = AnalysisException("UNION: the number of columns must match")
+            attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+            raise exception
+        target_left_dtypes, target_right_dtypes = [], []
+        for left_type, right_type in zip(left_dtypes, right_dtypes):
+            match (left_type, right_type):
+                case (snowpark.types._NumericType(), snowpark.types.StringType()) | (
+                    snowpark.types.StringType(),
+                    snowpark.types._NumericType(),
+                ):
+                    common_type = snowpark.types.StringType()
+                    target_left_dtypes.append(common_type)
+                    target_right_dtypes.append(common_type)
+                case _:
+                    target_left_dtypes.append(left_type)
+                    target_right_dtypes.append(right_type)
+
+        left_result = cast_columns(
+            left_result,
+            left_dtypes,
+            target_left_dtypes,
+            left_result.column_map,
+        )
+        right_result = cast_columns(
+            right_result,
+            right_dtypes,
+            target_right_dtypes,
+            right_result.column_map,
+        )
+        left_df = left_result.dataframe
+        right_df = right_result.dataframe
+
    if rel.set_op.is_all:
        # Snowflake except removes all duplicated rows. In order to handle the case,
        # we add a partition row number column to the df to make duplicated rows unique to