snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of snowpark-connect might be problematic.

Files changed (106)
  1. snowflake/snowpark_connect/column_name_handler.py +91 -40
  2. snowflake/snowpark_connect/column_qualifier.py +0 -4
  3. snowflake/snowpark_connect/config.py +9 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  5. snowflake/snowpark_connect/expression/literal.py +12 -12
  6. snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
  7. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
  9. snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
  10. snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
  11. snowflake/snowpark_connect/relation/map_join.py +454 -252
  12. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  13. snowflake/snowpark_connect/relation/map_sql.py +335 -90
  14. snowflake/snowpark_connect/relation/read/map_read.py +9 -1
  15. snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
  16. snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
  17. snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
  18. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
  19. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  20. snowflake/snowpark_connect/relation/read/utils.py +41 -0
  21. snowflake/snowpark_connect/relation/utils.py +50 -2
  22. snowflake/snowpark_connect/relation/write/map_write.py +251 -292
  23. snowflake/snowpark_connect/resources_initializer.py +25 -13
  24. snowflake/snowpark_connect/server.py +9 -24
  25. snowflake/snowpark_connect/type_mapping.py +2 -0
  26. snowflake/snowpark_connect/typed_column.py +2 -2
  27. snowflake/snowpark_connect/utils/context.py +0 -14
  28. snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
  29. snowflake/snowpark_connect/utils/sequence.py +21 -0
  30. snowflake/snowpark_connect/utils/session.py +4 -1
  31. snowflake/snowpark_connect/utils/udf_helper.py +1 -0
  32. snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
  35. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
  36. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  37. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  39. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  99. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
  100. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
  101. {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
  102. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
  103. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
  104. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
  105. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
  106. {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py
@@ -1,16 +1,23 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import dataclasses
+from collections.abc import Callable
+from enum import Enum
 from functools import reduce
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors import AnalysisException
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
-from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
+from snowflake.snowpark import DataFrame
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
+)
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
@@ -43,6 +50,23 @@ from snowflake.snowpark_connect.utils.telemetry import (
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
 
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+    just_left_columns: bool
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
+
+
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
@@ -52,18 +76,321 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container = filter_metadata_columns(left_container)
     right_container = filter_metadata_columns(right_container)
 
-    left_input: snowpark.DataFrame = left_container.dataframe
-    right_input: snowpark.DataFrame = right_container.dataframe
+    # if there are any conflicting snowpark columns, this is the time to rename them
+    left_container, right_container = _disambiguate_snowpark_columns(
+        left_container, right_container, rel
+    )
+
+    join_info = _get_join_info(rel, left_container, right_container)
+
+    match join_info.condition_type:
+        case ConditionType.JOIN_CONDITION:
+            result_container = _join_using_condition(
+                left_container, right_container, join_info, rel
+            )
+        case ConditionType.USING_COLUMNS:
+            result_container = _join_using_columns(
+                left_container, right_container, join_info
+            )
+        case _:
+            result_container = _join_unconditionally(
+                left_container, right_container, join_info
+            )
+
+    # Fix for USING join column references with different plan IDs
+    # After a USING join, references to the right dataframe's columns should resolve
+    # to the result dataframe that contains the merged columns
+    if (
+        join_info.is_using_columns()
+        and rel.join.right.HasField("common")
+        and rel.join.right.common.HasField("plan_id")
+    ):
+        right_plan_id = rel.join.right.common.plan_id
+        set_plan_id_map(right_plan_id, result_container)
+
+    # For FULL OUTER joins, we also need to map the left dataframe's plan_id
+    # since both columns are replaced with a coalesced column
+    if (
+        join_info.is_using_columns()
+        and join_info.join_type == "full_outer"
+        and rel.join.left.HasField("common")
+        and rel.join.left.common.HasField("plan_id")
+    ):
+        left_plan_id = rel.join.left.common.plan_id
+        set_plan_id_map(left_plan_id, result_container)
+
+    return result_container
+
+
+def _join_unconditionally(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    if info.join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
+        exception = SparkException.implicit_cartesian_product("inner")
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+    join_type = info.join_type
+
+    # For outer joins without a condition, we need to use a TRUE condition
+    # to match Spark's behavior.
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=snowpark_fn.lit(True)
+        if join_type in ["left", "right", "full_outer"]
+        else None,
+        how=join_type,
+    )
+
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)
+
+    if info.just_left_columns:
+        columns = left_container.column_map.columns
+        column_metadata = left_container.column_map.column_metadata
+        result = result.select(*left_container.column_map.get_snowpark_columns())
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+    )
+
+
+def _join_using_columns(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    join_columns = info.join_columns
+
+    def _validate_using_column(
+        column: str, container: DataFrameContainer, side: str
+    ) -> None:
+        if (
+            container.column_map.get_snowpark_column_name_from_spark_column_name(
+                column, allow_non_exists=True, return_first=True
+            )
+            is None
+        ):
+            exception = AnalysisException(
+                USING_COLUMN_NOT_FOUND_ERROR.format(
+                    column, side, container.column_map.get_spark_columns()
+                )
+            )
+            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+            raise exception
+
+    for col in join_columns:
+        _validate_using_column(col, left_container, "left")
+        _validate_using_column(col, right_container, "right")
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # The inputs will have different snowpark names for the same spark name,
+    # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
+    # then drop right["a"] and right["b"].
+    snowpark_using_columns = [
+        (
+            snowpark_fn.col(
+                left_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
+            snowpark_fn.col(
+                right_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
+        )
+        for spark_name in join_columns
+    ]
+
+    # this is a condition join, so it will contain left + right columns
+    # we need to postprocess this later to have a correct projection
+    joined_df = left_input.join(
+        right=right_input,
+        on=reduce(
+            snowpark.Column.__and__,
+            (left == right for left, right in snowpark_using_columns),
+        ),
+        how=info.join_type,
+    )
+
+    # figure out default column ordering after the join
+    columns = left_container.column_map.get_columns_after_join(
+        right_container.column_map, join_columns, info.join_type
+    )
+
+    # For outer joins, we need to preserve join keys from both sides using COALESCE
+    if info.join_type == "full_outer":
+        coalesced_columns = []
+        coalesced_column_names = []
+        for i, (left_col, right_col) in enumerate(snowpark_using_columns):
+            # spark uses the left side spark name
+            spark_name = columns[i].spark_name
+            new_snowpark_name = make_unique_snowpark_name(spark_name)
+            coalesced_col = snowpark_fn.coalesce(left_col, right_col).alias(
+                new_snowpark_name
+            )
+            coalesced_columns.append(coalesced_col)
+            coalesced_column_names.append((spark_name, new_snowpark_name))
+
+        # join columns need to be replaced, so we need the original names for schema lookup later
+        snowpark_names_for_schema_lookup = [c.snowpark_name for c in columns]
+
+        # we need to use the coalesced columns instead of the left-side join columns
+        columns = columns[len(join_columns) :]
+
+        non_join_columns = [snowpark_fn.col(c.snowpark_name) for c in columns]
+        result = joined_df.select(coalesced_columns + non_join_columns)
+
+        spark_names = [spark_name for spark_name, _ in coalesced_column_names] + [
+            c.spark_name for c in columns
+        ]
+        snowpark_names = [
+            snowpark_name for _, snowpark_name in coalesced_column_names
+        ] + [c.snowpark_name for c in columns]
+        qualifiers = ([set()] * len(join_columns)) + [c.qualifiers for c in columns]
+
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=spark_names,
+            snowpark_column_names=snowpark_names,
+            column_metadata=_combine_metadata(left_container, right_container),
+            column_qualifiers=qualifiers,
+            cached_schema_getter=_build_joined_schema(
+                snowpark_names_for_schema_lookup,
+                left_input,
+                right_input,
+                snowpark_names,
+            ),
+        )
+
+    if info.just_left_columns:
+        # we just need the left columns
+        columns = columns[: len(left_container.column_map.columns)]
+        snowpark_columns = [c.snowpark_name for c in columns]
+        result = joined_df.select(*snowpark_columns)
+
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[c.spark_name for c in columns],
+            snowpark_column_names=snowpark_columns,
+            column_metadata=left_container.column_map.column_metadata,
+            column_qualifiers=[c.qualifiers for c in columns],
+            cached_schema_getter=_build_joined_schema(
+                snowpark_columns, left_input, right_input
+            ),
+        )
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+    result = joined_df.select(*snowpark_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=_combine_metadata(left_container, right_container),
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+    )
+
+
+def _join_using_condition(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+    rel: relation_proto.Relation,
+) -> DataFrameContainer:
+    left_columns = left_container.column_map.get_spark_columns()
+    right_columns = right_container.column_map.get_spark_columns()
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # All PySpark join types are in the format of JOIN_TYPE_XXX.
+    # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
+    pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
+        10:
+    ].replace("_", " ")
+    with push_sql_scope(), push_evaluating_join_condition(
+        pyspark_join_type, left_columns, right_columns
+    ):
+        if left_container.alias is not None:
+            set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
+        if right_container.alias is not None:
+            set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
+        # resolve join condition expression
+        _, join_expression = map_single_column_expression(
+            rel.join.join_condition,
+            column_mapping=JoinColumnNameMap(
+                left_container.column_map,
+                right_container.column_map,
+            ),
+            typer=JoinExpressionTyper(left_input, right_input),
+        )
+
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=join_expression.col,
+        how=info.join_type,
+    )
+
+    # column order is already correct, so we just take the left + right side list
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)
+
+    if info.just_left_columns:
+        # we just need left-side columns
+        columns = left_container.column_map.columns
+        result = result.select(*[c.snowpark_name for c in columns])
+        column_metadata = left_container.column_map.column_metadata
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+    )
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
     is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
+    join_columns = rel.join.using_columns
     if is_natural_join:
         rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
         common_spark_columns = [
             x for x in left_spark_columns if x in right_spark_columns
         ]
-        using_columns = common_spark_columns
+        join_columns = common_spark_columns
 
     match rel.join.join_type:
         case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
@@ -90,202 +417,108 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
             attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
             raise exception
 
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
 
-    if rel.join.HasField("join_condition"):
-        assert not using_columns
+    if join_type == "cross" and has_join_condition:
+        # if the user provided any condition, it's no longer a cross join
+        join_type = "inner"
 
-        left_columns = list(left_container.column_map.spark_to_col.keys())
-        right_columns = list(right_container.column_map.spark_to_col.keys())
+    if has_join_condition:
+        assert not is_using_columns
 
-        # All PySpark join types are in the format of JOIN_TYPE_XXX.
-        # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
-        pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
-            10:
-        ].replace("_", " ")
-        with push_sql_scope(), push_evaluating_join_condition(
-            pyspark_join_type, left_columns, right_columns
-        ):
-            if left_container.alias is not None:
-                set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
-            if right_container.alias is not None:
-                set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
-            _, join_expression = map_single_column_expression(
-                rel.join.join_condition,
-                column_mapping=JoinColumnNameMap(
-                    left_container.column_map,
-                    right_container.column_map,
-                ),
-                typer=JoinExpressionTyper(left_input, right_input),
-            )
-        result: snowpark.DataFrame = left_input.join(
-            right=right_input,
-            on=join_expression.col,
-            how=join_type,
-            lsuffix="_left",
-            rsuffix="_right",
-        )
-    elif using_columns:
-        if any(
-            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
-                )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
-        if any(
-            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            )
-            is None
-            for c in using_columns
-        ):
-            exception = AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
-                )
-            )
-            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
-            raise exception
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
 
-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
-        # We cannot assume that Snowpark will have the same names for left and right columns,
-        # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
-        # then drop right["a"] and right["b"].
-        snowpark_using_columns = [
-            (
-                left_input[
-                    left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        lft, return_first=True
-                    )
-                ],
-                right_input[
-                    right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        r, return_first=True
-                    )
-                ],
+    # Join types that only return columns from the left side:
+    # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
+    # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
+    # Both preserve only the columns from the left DataFrame without adding any columns from the right.
+    just_left_columns = join_type in ["leftanti", "leftsemi"]
+
+    return JoinInfo(join_type, condition_type, join_columns, just_left_columns)
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer, rel: relation_proto.Relation
+) -> tuple[DataFrameContainer, DataFrameContainer]:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return left, right
+
+    left_plan = rel.join.left.common.plan_id
+    right_plan = rel.join.right.common.plan_id
+
+    if left_plan == right_plan:
+        # don't overwrite plan_id map for self joins
+        right_plan = None
+
+    # rename and create new right container
+    # TODO: rename both sides after SNOW-2382499
+    return left, _disambiguate_container(
+        right, conflicting_snowpark_columns, right_plan
+    )
+
+
+def _disambiguate_container(
+    container: DataFrameContainer,
+    conflicting_snowpark_columns: set[str],
+    plan_id: Optional[int],
+) -> DataFrameContainer:
+    column_map = container.column_map
+    disambiguated_columns = []
+    disambiguated_snowpark_names = []
+    for c in column_map.columns:
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
             )
-            for lft, r in using_columns
-        ]
-        joined_df = left_input.join(
-            right=right_input,
-            on=reduce(
-                snowpark.Column.__and__,
-                (left == right for left, right in snowpark_using_columns),
-            ),
-            how=join_type,
-        )
-        # For outer joins, we need to preserve join keys from both sides using COALESCE
-        if join_type == "full_outer":
-            coalesced_columns = []
-            columns_to_drop = []
-            for i, (left_col, right_col) in enumerate(snowpark_using_columns):
-                # Use the original user-specified column name to preserve case sensitivity
-                original_column_name = rel.join.using_columns[i]
-                coalesced_col = snowpark_fn.coalesce(left_col, right_col).alias(
-                    original_column_name
-                )
-                coalesced_columns.append(coalesced_col)
-                columns_to_drop.extend([left_col, right_col])
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+    disambiguated_df = container.dataframe.select(*disambiguated_columns)
 
-            other_columns = [
-                snowpark_fn.col(col_name)
-                for col_name in joined_df.columns
-                if col_name not in [col.getName() for col in columns_to_drop]
+    def _schema_getter() -> StructType:
+        fields = container.dataframe.schema.fields
+        return StructType(
+            [
+                StructField(name, fields[i].datatype, fields[i].nullable)
+                for i, name in enumerate(disambiguated_snowpark_names)
             ]
-            result = joined_df.select(coalesced_columns + other_columns)
-        else:
-            result = joined_df.drop(*(right for _, right in snowpark_using_columns))
-    else:
-        if join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
-            exception = SparkException.implicit_cartesian_product("inner")
-            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
-            raise exception
-        result: snowpark.DataFrame = left_input.join(
-            right=right_input,
-            how=join_type,
        )
 
-    if join_type in ["leftanti", "leftsemi"]:
-        # Join types that only return columns from the left side:
-        # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
-        # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
-        # Both preserve only the columns from the left DataFrame without adding any columns from the right.
-        spark_cols_after_join: list[str] = left_container.column_map.get_spark_columns()
-        qualifiers = left_container.column_map.get_qualifiers()
-    else:
-        # Add Spark columns and plan_ids from left DF
-        spark_cols_after_join: list[str] = list(
-            left_container.column_map.get_spark_columns()
-        ) + [
-            spark_col
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
-            )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once
-        ]
+    disambiguated_container = DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=container.table_name,
+        cached_schema_getter=_schema_getter,
+    )
 
-        qualifiers: list[set[ColumnQualifier]] = list(
-            left_container.column_map.get_qualifiers()
-        ) + [
-            {right_container.column_map.get_qualifier_for_spark_column(spark_col)}
-            for i, spark_col in enumerate(
-                right_container.column_map.get_spark_columns()
-            )
-            if spark_col not in case_corrected_right_columns
-            or spark_col
-            in right_container.column_map.get_spark_columns()[
-                :i
-            ]  # this is to make sure we only remove the column once]
-        ]
+    # since we just renamed some snowpark columns, we need to update the dataframe container for the given plan_id
+    # TODO: is there a better way to do this?
+    if plan_id is not None:
+        set_plan_id_map(plan_id, disambiguated_container)
+
+    return disambiguated_container
 
-    column_metadata = {}
-    if left_container.column_map.column_metadata:
-        column_metadata.update(left_container.column_map.column_metadata)
 
+def _combine_metadata(
+    left_container: DataFrameContainer, right_container: DataFrameContainer
+) -> dict:
+    column_metadata = dict(left_container.column_map.column_metadata or {})
     if right_container.column_map.column_metadata:
         for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
@@ -297,7 +530,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                     snowpark_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         key
                     )
-                    expr_id = right_input[snowpark_name]._expression.expr_id
+                    expr_id = right_container.dataframe[
+                        snowpark_name
+                    ]._expression.expr_id
                     updated_key = COLUMN_METADATA_COLLISION_KEY.format(
                         expr_id=expr_id, key=snowpark_name
                     )
@@ -305,68 +540,35 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                 except Exception:
                     # ignore any errors that happens while fetching the metadata
                     pass
+    return column_metadata
 
-    result_container = DataFrameContainer.create_with_column_mapping(
-        dataframe=result,
-        spark_column_names=spark_cols_after_join,
-        snowpark_column_names=result.columns,
-        column_metadata=column_metadata,
-        column_qualifiers=qualifiers,
-    )
 
-    # Fix for USING join column references with different plan IDs
-    # After a USING join, references to the right dataframe's columns should resolve
-    # to the result dataframe that contains the merged columns
-    if (
-        using_columns
-        and rel.join.right.HasField("common")
-        and rel.join.right.common.HasField("plan_id")
-    ):
-        right_plan_id = rel.join.right.common.plan_id
-        set_plan_id_map(right_plan_id, result_container)
+def _build_joined_schema(
+    snowpark_columns: list[str],
+    left_input: DataFrame,
+    right_input: DataFrame,
+    target_snowpark_columns: Optional[list[str]] = None,
+) -> Callable[[], StructType]:
+    """
+    Builds a lazy schema for the joined dataframe, based on the given snowpark_columns and input dataframes.
+    In case of full outer joins, we need a separate target_snowpark_columns, since join columns will have different
+    names in the output than in any input.
+    """
 
-    # For FULL OUTER joins, we also need to map the left dataframe's plan_id
-    # since both columns are replaced with a coalesced column
-    if (
-        using_columns
-        and join_type == "full_outer"
-        and rel.join.left.HasField("common")
-        and rel.join.left.common.HasField("plan_id")
-    ):
-        left_plan_id = rel.join.left.common.plan_id
-        set_plan_id_map(left_plan_id, result_container)
-
-    if rel.join.using_columns:
-        # When join 'using_columns', the 'join columns' should go first in result DF.
-        idxs_to_shift = [
-            spark_cols_after_join.index(left_col_name)
-            for left_col_name in case_corrected_left_columns
-        ]
+    def _schema_getter() -> StructType:
+        all_fields = left_input.schema.fields + right_input.schema.fields
+        fields: dict[str, StructField] = {f.name: f for f in all_fields}
+        target_names = target_snowpark_columns or snowpark_columns
 
-        def reorder(lst: list) -> list:
-            to_move = [lst[i] for i in idxs_to_shift]
-            remaining = [el for i, el in enumerate(lst) if i not in idxs_to_shift]
-            return to_move + remaining
+        assert len(snowpark_columns) == len(target_names)
 
-        # Create reordered DataFrame
-        reordered_df = result_container.dataframe.select(
-            [snowpark_fn.col(c) for c in reorder(result_container.dataframe.columns)]
-        )
-
-        # Create new container with reordered metadata
-        original_df = result_container.dataframe
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=reordered_df,
-            spark_column_names=reorder(result_container.column_map.get_spark_columns()),
-            snowpark_column_names=reorder(
-                result_container.column_map.get_snowpark_columns()
-            ),
-            column_metadata=column_metadata,
-            column_qualifiers=reorder(qualifiers),
-            table_name=result_container.table_name,
-            cached_schema_getter=lambda: snowpark.types.StructType(
-                reorder(original_df.schema.fields)
-            ),
+        return StructType(
+            [
+                StructField(
+                    target_names[i], fields[name].datatype, fields[name].nullable
+                )
+                for i, name in enumerate(snowpark_columns)
+            ]
         )
 
-    return result_container
+    return _schema_getter