snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py
@@ -1,35 +1,50 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-from collections import Counter
+import dataclasses
+from collections.abc import Callable
+from copy import copy
+from enum import Enum
 from functools import reduce
+from typing import Optional

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
-from pyspark.errors.exceptions.base import AnalysisException
+from pyspark.errors import AnalysisException
+from pyspark.errors.exceptions.connect import IllegalArgumentException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark._internal.analyzer.analyzer_utils import (
-    quote_name_without_upper_casing,
-    unquote_if_quoted,
+from snowflake.snowpark import Column, DataFrame
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNames,
+    ColumnQualifier,
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
 )
-from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
-from snowflake.snowpark_connect.error.error_utils import SparkException
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import (
+    SparkException,
+    attach_custom_error_code,
+)
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
 from snowflake.snowpark_connect.expression.typer import JoinExpressionTyper
-from snowflake.snowpark_connect.hidden_column import HiddenColumn
 from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    without_internal_columns,
+)
 from snowflake.snowpark_connect.utils.context import (
     push_evaluating_join_condition,
     push_sql_scope,
+    set_plan_id_map,
     set_sql_plan_name,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -38,447 +53,583 @@ from snowflake.snowpark_connect.utils.telemetry import (

 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"

-DUPLICATED_JOIN_COL_LSUFFIX = "_left"
-DUPLICATED_JOIN_COL_RSUFFIX = "_right"
+
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+    just_left_columns: bool
+    is_join_with: bool
+    is_left_struct: bool
+    is_right_struct: bool
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS


 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)

-    left_input: snowpark.DataFrame = left_container.dataframe
-    right_input: snowpark.DataFrame = right_container.dataframe
-    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
-    using_columns = rel.join.using_columns
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
+    # Remove any metadata columns(like metada$filename) present in the dataframes.
+    # We cannot support inputfilename for multisources as each dataframe has it's own source.
+    left_container = without_internal_columns(left_container)
+    right_container = without_internal_columns(right_container)

-    match rel.join.join_type:
-        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
-            # TODO: Understand what UNSPECIFIED Join type is
-            raise SnowparkConnectNotImplementedError("Unspecified Join Type")
-        case relation_proto.Join.JOIN_TYPE_INNER:
-            join_type = "inner"
-        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
-            join_type = "full_outer"
-        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
-            join_type = "left"
-        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
-            join_type = "right"
-        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
-            join_type = "leftanti"
-        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
-            join_type = "leftsemi"
-        case relation_proto.Join.JOIN_TYPE_CROSS:
-            join_type = "cross"
-        case other:
-            raise SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
-    hidden_columns = set()
-    # Propagate the hidden columns from left/right inputs to the result in case of chained joins
-    if left_container.column_map.hidden_columns:
-        hidden_columns.update(left_container.column_map.hidden_columns)
-
-    if right_container.column_map.hidden_columns:
-        hidden_columns.update(right_container.column_map.hidden_columns)
-
-    if rel.join.HasField("join_condition"):
-        assert not using_columns
-
-        left_columns = list(left_container.column_map.spark_to_col.keys())
-        right_columns = list(right_container.column_map.spark_to_col.keys())
-
-        # All PySpark join types are in the format of JOIN_TYPE_XXX.
-        # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
-        pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
-            10:
-        ].replace("_", " ")
-        with push_sql_scope(), push_evaluating_join_condition(
-            pyspark_join_type, left_columns, right_columns
-        ):
-            if left_container.alias is not None:
-                set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
-            if right_container.alias is not None:
-                set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
-            _, join_expression = map_single_column_expression(
-                rel.join.join_condition,
-                column_mapping=JoinColumnNameMap(
-                    left_container.column_map,
-                    right_container.column_map,
-                ),
-                typer=JoinExpressionTyper(left_input, right_input),
+    left_plan = rel.join.left.common.plan_id
+    right_plan = rel.join.right.common.plan_id
+
+    # if there are any conflicting snowpark columns, this is the time to rename them
+    disambiguated_right_container = _disambiguate_snowpark_columns(
+        left_container, right_container, right_plan if left_plan != right_plan else None
+    )
+
+    join_info = _get_join_info(rel, left_container, disambiguated_right_container)
+
+    match join_info.condition_type:
+        case ConditionType.JOIN_CONDITION:
+            result_container = _join_using_condition(
+                left_container,
+                disambiguated_right_container,
+                join_info,
+                rel,
+                right_container if left_plan == right_plan else None,
             )
-            result: snowpark.DataFrame = left_input.join(
-                right=right_input,
-                on=join_expression.col,
-                how=join_type,
-                lsuffix=DUPLICATED_JOIN_COL_LSUFFIX,
-                rsuffix=DUPLICATED_JOIN_COL_RSUFFIX,
-            )
-    elif using_columns:
-        if any(
-            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
+        case ConditionType.USING_COLUMNS:
+            result_container = _join_using_columns(
+                left_container,
+                disambiguated_right_container,
+                join_info,
             )
-            is None
-            for c in using_columns
-        ):
-            import pyspark
-
-            raise pyspark.errors.AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
-                )
+        case _:
+            result_container = _join_unconditionally(
+                left_container, disambiguated_right_container, join_info
             )
-        if any(
-            right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
+
+    return result_container
+
+
+def _join_unconditionally(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    if info.join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
+        exception = SparkException.implicit_cartesian_product("inner")
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+    join_type = info.join_type
+
+    # For outer joins without a condition, we need to use a TRUE condition
+    # to match Spark's behavior.
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=snowpark_fn.lit(True)
+        if join_type in ["left", "right", "full_outer"]
+        else None,
+        how=join_type,
+    )
+
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)
+
+    if info.just_left_columns:
+        columns = left_container.column_map.columns
+        column_metadata = left_container.column_map.column_metadata
+        result = result.select(*left_container.column_map.get_snowpark_columns())
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )
+
+
+def _join_using_columns(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    join_columns = info.join_columns
+
+    def _validate_using_column(
+        column: str, container: DataFrameContainer, side: str
+    ) -> None:
+        if (
+            container.column_map.get_snowpark_column_name_from_spark_column_name(
+                column, allow_non_exists=True, return_first=True
             )
             is None
-            for c in using_columns
         ):
-            import pyspark
-
-            raise pyspark.errors.AnalysisException(
+            exception = AnalysisException(
                 USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "right",
-                    right_container.column_map.get_spark_columns(),
+                    column, side, container.column_map.get_spark_columns()
                 )
             )
-
-        using_columns_snowpark_names = (
-            left_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
-            )
+            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+            raise exception
+
+    for col in join_columns:
+        _validate_using_column(col, left_container, "left")
+        _validate_using_column(col, right_container, "right")
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # The inputs will have different snowpark names for the same spark name,
+    # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
+    # then drop right["a"] and right["b"].
+    snowpark_using_columns = [
+        (
+            snowpark_fn.col(
+                left_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
+            snowpark_fn.col(
+                right_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
         )
+        for spark_name in join_columns
+    ]
+
+    # this is a condition join, so it will contain left + right columns
+    # we need to postprocess this later to have a correct projection
+    joined_df = left_input.join(
+        right=right_input,
+        on=reduce(
+            snowpark.Column.__and__,
+            (left == right for left, right in snowpark_using_columns),
+        ),
+        how=info.join_type,
+    )

-        using_columns_snowpark_types = [
-            left_container.dataframe.schema.fields[idx].datatype
-            for idx, col in enumerate(left_container.column_map.get_snowpark_columns())
-            if col in using_columns_snowpark_names
-        ]
+    # figure out default column ordering after the join
+    columns = left_container.column_map.get_columns_after_join(
+        right_container.column_map, join_columns, info.join_type
+    )
+
+    if info.join_type in ["full_outer", "left", "right"]:
+        all_columns_for_select = []
+        all_column_names = []

-        # Round trip the using columns through the column map to get the correct names
-        # in order to support case sensitivity.
-        # TODO: case_corrected_left_columns / case_corrected_right_columns may no longer be required as Snowpark dataframe preserves the column casing now.
-        case_corrected_left_columns = (
-            left_container.column_map.get_spark_column_names_from_snowpark_column_names(
-                using_columns_snowpark_names
+        for column_info in columns[: len(join_columns)]:
+            spark_name = column_info.spark_name
+            left_sp_name = left_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                spark_name, return_first=True
            )
-        )
-        case_corrected_right_columns = right_container.column_map.get_spark_column_names_from_snowpark_column_names(
-            right_container.column_map.get_snowpark_column_names_from_spark_column_names(
-                list(using_columns), return_first=True
+            right_sp_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                spark_name, return_first=True
            )
-        )
-        using_columns = zip(case_corrected_left_columns, case_corrected_right_columns)
-        # We cannot assume that Snowpark will have the same names for left and right columns,
-        # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
-        # then drop right["a"] and right["b"].
-        snowpark_using_columns = [
-            (
-                left_input[
-                    left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        lft, return_first=True
-                    )
-                ],
-                right_input[
-                    right_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                        r, return_first=True
+
+            if info.join_type == "full_outer":
+                new_sp_name = make_unique_snowpark_name(spark_name)
+                all_columns_for_select.append(
+                    snowpark_fn.coalesce(
+                        snowpark_fn.col(left_sp_name), snowpark_fn.col(right_sp_name)
+                    ).alias(new_sp_name)
+                )
+                all_column_names.append(
+                    ColumnNames(
+                        spark_name,
+                        new_sp_name,
+                        set(),
+                        equivalent_snowpark_names=set(),
+                        is_hidden=False,
                    )
-                ],
-            )
-            for lft, r in using_columns
-        ]
-        joined_df = left_input.join(
-            right=right_input,
-            on=reduce(
-                snowpark.Column.__and__,
-                (left == right for left, right in snowpark_using_columns),
-            ),
-            how=join_type,
-            rsuffix=DUPLICATED_JOIN_COL_RSUFFIX,
-        )
-        # If we disambiguated the snowpark_using_columns during the join, we need to update 'snowpark_using_columns' to
-        # use the disambiguated names.
-        disambiguated_snowpark_using_columns = []
+                )

-        # Ignore disambiguation for LEFT SEMI JOIN and LEFT ANTI JOIN because they drop the right columns, so it'll never disambiguate.
-        if join_type in ["leftsemi", "leftanti"]:
-            disambiguated_snowpark_using_columns = snowpark_using_columns
-        else:
-            normalized_joined_columns = [
-                unquote_if_quoted(col) for col in joined_df.columns
-            ]
-            # snowpark_using_columns is a list of tuples of snowpark columns, joined_df.columns is a list of strings of column names
-            for (left, right) in snowpark_using_columns:
-                normalized_left_name = unquote_if_quoted(left.getName())
-                normalized_right_name = unquote_if_quoted(right.getName())
-
-                # are both left and right in joined_df? if not, it's been disambiguated
-                if (
-                    normalized_left_name in normalized_joined_columns
-                    and normalized_right_name in normalized_joined_columns
-                ):
-                    # we want to just add this
-                    disambiguated_snowpark_using_columns.append((left, right))
-                else:
-                    # we need to figure out the disambiguated names and add those - it only disambiguates if left == right
-                    disambiguated_left: snowpark.Column | None = None
-                    disambiguated_right: snowpark.Column | None = None
-
-                    for col in normalized_joined_columns:
-                        quoted_col = f'"{col}"'
-                        # get the column name and cross check it to see if it ends with the og name
-                        if col.endswith(normalized_left_name) and col.startswith("l_"):
-                            disambiguated_left = joined_df[quoted_col]
-                        elif col.endswith(normalized_right_name) and col.startswith(
-                            "r_"
-                        ):
-                            disambiguated_right = joined_df[quoted_col]
-
-                        # If we have both disambiguated columns, we can break out of the loop to save processing time
-                        if (
-                            disambiguated_left is not None
-                            and disambiguated_right is not None
-                        ):
-                            break
-                    if disambiguated_left is None or disambiguated_right is None:
-                        raise AnalysisException(
-                            f"Disambiguated columns not found for {normalized_left_name} and {normalized_right_name}."
+                for sp_name, container in [
+                    (left_sp_name, left_container),
+                    (right_sp_name, right_container),
+                ]:
+                    all_columns_for_select.append(snowpark_fn.col(sp_name))
+                    all_column_names.append(
+                        ColumnNames(
+                            spark_name,
+                            sp_name,
+                            container.column_map.get_qualifiers_for_snowpark_column(
+                                sp_name
+                            ),
+                            equivalent_snowpark_names=container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                                sp_name
+                            ),
+                            is_hidden=True,
+                        )
+                    )
+            else:
+                for sp_name, container, side in [
+                    (left_sp_name, left_container, "left"),
+                    (right_sp_name, right_container, "right"),
+                ]:
+                    all_columns_for_select.append(snowpark_fn.col(sp_name))
+                    qualifiers = (
+                        container.column_map.get_qualifiers_for_snowpark_column(sp_name)
+                    )
+                    equivalent_snowpark_names = set()
+                    equivalent_snowpark_names.update(
+                        container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                            sp_name
+                        )
+                    )
+                    is_visible = info.join_type == side
+                    if is_visible:
+                        qualifiers = qualifiers | {ColumnQualifier(())}
+                    all_column_names.append(
+                        ColumnNames(
+                            spark_name,
+                            sp_name,
+                            qualifiers,
+                            equivalent_snowpark_names=equivalent_snowpark_names,
+                            is_hidden=not is_visible,
                        )
-                    disambiguated_snowpark_using_columns.append(
-                        (disambiguated_left, disambiguated_right)
                    )

-        # For outer joins, we need to preserve join keys from both sides using COALESCE
-        """
-        CHANGES:
-        - IF CASE
-        - Need to drop the using columns
-        - Need to create the hidden_columns DF with the using columns from right and left
-        - ELSE CASE
-        - Need to drop the right side using columns
-        - Need to create the hidden_columns DF with the using columns from right
-        """
-        if join_type == "full_outer":
-            coalesced_columns = []
-            for i, (left_col, _right_col) in enumerate(snowpark_using_columns):
-                # Use the original user-specified column name to preserve case sensitivity
-                # Use the disambiguated columns for coalescing
-                disambiguated_left_col = disambiguated_snowpark_using_columns[i][0]
-                disambiguated_right_col = disambiguated_snowpark_using_columns[i][1]
-
-                coalesced_col = snowpark_fn.coalesce(
-                    disambiguated_left_col, disambiguated_right_col
-                ).alias(left_col.get_name())
-                coalesced_columns.append(coalesced_col)
-
-                # Create HiddenColumn objects for each hidden column
-                hidden_left = HiddenColumn(
-                    hidden_snowpark_name=disambiguated_left_col.getName(),
-                    spark_name=case_corrected_left_columns[i],
-                    visible_snowpark_name=left_col.get_name(),
-                    qualifiers=left_container.column_map.get_qualifier_for_spark_column(
-                        case_corrected_left_columns[i]
-                    ),
-                    original_position=left_container.column_map.get_spark_columns().index(
-                        case_corrected_left_columns[i]
-                    ),
-                )
-
-                hidden_right = HiddenColumn(
-                    hidden_snowpark_name=disambiguated_right_col.getName(),
-                    spark_name=case_corrected_right_columns[i],
-                    visible_snowpark_name=left_col.get_name(),
-                    qualifiers=right_container.column_map.get_qualifier_for_spark_column(
-                        case_corrected_right_columns[i]
-                    ),
-                    original_position=right_container.column_map.get_spark_columns().index(
-                        case_corrected_right_columns[i]
-                    ),
-                )
-                hidden_columns.update(
-                    [
-                        hidden_left,
-                        hidden_right,
-                    ]
-                )
+        for c in columns[len(join_columns) :]:
+            all_columns_for_select.append(snowpark_fn.col(c.snowpark_name))
+            all_column_names.append(c)

-            # All non-hidden columns (not including the coalesced columns)
-            other_columns = [
-                snowpark_fn.col(col_name)
-                for col_name in joined_df.columns
-                if col_name not in [col.hidden_snowpark_name for col in hidden_columns]
-            ]
-            result = joined_df.select(coalesced_columns + other_columns)
+        result = joined_df.select(all_columns_for_select)
+        snowpark_names_for_schema = [c.snowpark_name for c in columns]

-        else:
-            result = joined_df.drop(*(right for _, right in snowpark_using_columns))
-            # We never run into the disambiguation case unless it's a full outer join.
-            for i, (left_col, right_col) in enumerate(
-                disambiguated_snowpark_using_columns
-            ):
-                # Only right side columns are hidden
-                hidden_col = HiddenColumn(
-                    hidden_snowpark_name=right_col.getName(),
-                    spark_name=case_corrected_right_columns[i],
-                    visible_snowpark_name=left_col.getName(),
-                    qualifiers=right_container.column_map.get_qualifier_for_spark_column(
-                        case_corrected_right_columns[i]
-                    ),
-                    original_position=right_container.column_map.get_spark_columns().index(
-                        case_corrected_right_columns[i]
-                    ),
-                )
-                hidden_columns.add(hidden_col)
-    else:
-        if join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
-            raise SparkException.implicit_cartesian_product("inner")
-        result: snowpark.DataFrame = left_input.join(
-            right=right_input,
-            how=join_type,
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[c.spark_name for c in all_column_names],
+            snowpark_column_names=[c.snowpark_name for c in all_column_names],
+            column_metadata=_combine_metadata(left_container, right_container),
+            column_qualifiers=[c.qualifiers for c in all_column_names],
+            column_is_hidden=[c.is_hidden for c in all_column_names],
+            cached_schema_getter=_build_joined_schema(
+                snowpark_names_for_schema,
+                left_input,
+                right_input,
+                all_column_names,
+            ),
+            equivalent_snowpark_names=[
+                c.equivalent_snowpark_names for c in all_column_names
+            ],
        )

-    if join_type in ["leftanti", "leftsemi"]:
-        # Join types that only return columns from the left side:
-        # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
-        # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
-        # Both preserve only the columns from the left DataFrame without adding any columns from the right.
-        spark_cols_after_join = left_container.column_map.get_spark_columns()
-        snowpark_cols_after_join = left_container.column_map.get_snowpark_columns()
-        snowpark_col_types = [
-            f.datatype for f in left_container.dataframe.schema.fields
-        ]
-        qualifiers = left_container.column_map.get_qualifiers()
-    elif join_type == "full_outer" and using_columns:
-        # We want the coalesced columns to be first, followed by all the left and right columns (excluding using columns)
-        spark_cols_after_join: list[str] = []
-        snowpark_cols_after_join: list[str] = []
-        snowpark_col_types: list[str] = []
-
-        left_container_snowpark_columns = (
-            left_container.column_map.get_snowpark_columns()
+    if info.just_left_columns:
+        # we just need the left columns
+        columns = columns[: len(left_container.column_map.columns)]
+        snowpark_columns = [c.snowpark_name for c in columns]
+        result = joined_df.select(*snowpark_columns)
+
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[c.spark_name for c in columns],
+            snowpark_column_names=snowpark_columns,
+            column_metadata=left_container.column_map.column_metadata,
+            column_qualifiers=[c.qualifiers for c in columns],
+            cached_schema_getter=_build_joined_schema(
+                snowpark_columns, left_input, right_input
+            ),
+            equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
        )
-        right_container_snowpark_columns = (
-            right_container.column_map.get_snowpark_columns()
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+    result = joined_df.select(*snowpark_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=_combine_metadata(left_container, right_container),
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )
+
+
+def _join_using_condition(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+    rel: relation_proto.Relation,
+    original_right_container: Optional[DataFrameContainer],
+) -> DataFrameContainer:
+    left_columns = left_container.column_map.get_spark_columns()
+    right_columns = right_container.column_map.get_spark_columns()
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # All PySpark join types are in the format of JOIN_TYPE_XXX.
+    # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
+    pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
+        10:
+    ].replace("_", " ")
+    with push_sql_scope(), push_evaluating_join_condition(
+        pyspark_join_type, left_columns, right_columns
+    ):
+        if left_container.alias is not None:
+            set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
+        if right_container.alias is not None:
+            set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
+        # resolve join condition expression
+        _, join_expression = map_single_column_expression(
+            rel.join.join_condition,
+            column_mapping=JoinColumnNameMap(
+                left_container.column_map,
+                # using the original (not disambiguated) right container is intended to break
+                # self join cases like a.join(a, a.id == a.id), since SAS can't handle them correctly
+                # and they fail in Spark Connect
+                (
+                    original_right_container
+                    if original_right_container
+                    else right_container
+                ).column_map,
+            ),
+            typer=JoinExpressionTyper(left_input, right_input),
        )

-        qualifiers = []
-        for i in range(len(case_corrected_left_columns)):
-            spark_cols_after_join.append(case_corrected_left_columns[i])
-            snowpark_cols_after_join.append(using_columns_snowpark_names[i])
-            snowpark_col_types.append(using_columns_snowpark_types[i])
-            qualifiers.append([])
-
-        # Handle adding left and right columns, excluding the using columns
-        for i, spark_col in enumerate(left_container.column_map.get_spark_columns()):
-            if (
-                spark_col not in case_corrected_left_columns
-                or spark_col in left_container.column_map.get_spark_columns()[:i]
-            ):
-                spark_cols_after_join.append(spark_col)
-                snowpark_cols_after_join.append(left_container_snowpark_columns[i])
-                qualifiers.append(
-                    left_container.column_map.get_qualifier_for_spark_column(spark_col)
-                )
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=join_expression.col,
+        how=info.join_type,
+    )

-                snowpark_col_types.append(
-                    left_container.dataframe.schema.fields[i].datatype
-                )
+    # early return for joinWith
+    if info.is_join_with:
+        return _join_with(left_container, right_container, result, info)

-        for i, spark_col in enumerate(right_container.column_map.get_spark_columns()):
-            if (
-                spark_col not in case_corrected_right_columns
-                or spark_col in right_container.column_map.get_spark_columns()[:i]
-            ):
-                spark_cols_after_join.append(spark_col)
-                snowpark_cols_after_join.append(right_container_snowpark_columns[i])
-                qualifiers.append(
-                    right_container.column_map.get_qualifier_for_spark_column(spark_col)
-                )
+    # column order is already correct, so we just take the left + right side list
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)

-                snowpark_col_types.append(
-                    right_container.dataframe.schema.fields[i].datatype
-                )
+    if info.just_left_columns:
+        # we just need left-side columns
+        columns = left_container.column_map.columns
+        result = result.select(*[c.snowpark_name for c in columns])
+        column_metadata = left_container.column_map.column_metadata

-    else:
-        spark_cols_after_join = left_container.column_map.get_spark_columns()
-        snowpark_cols_after_join = left_container.column_map.get_snowpark_columns()
-        snowpark_col_types = [
-            f.datatype for f in left_container.dataframe.schema.fields
-        ]
+    snowpark_columns = [c.snowpark_name for c in columns]

-        qualifiers = left_container.column_map.get_qualifiers()
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )

-        right_df_snowpark_columns = right_container.column_map.get_snowpark_columns()

-        for i, spark_col in enumerate(right_container.column_map.get_spark_columns()):
-            if (
-                spark_col not in case_corrected_right_columns
-                or spark_col in right_container.column_map.get_spark_columns()[:i]
-            ):
-                spark_cols_after_join.append(spark_col)
-                snowpark_cols_after_join.append(right_df_snowpark_columns[i])
-                snowpark_col_types.append(
-                    right_container.dataframe.schema.fields[i].datatype
-                )
+def _join_with(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    joined_df: DataFrame,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    # joinWith always returns 2 columns
+    left_column = "_1"
+    right_column = "_2"
+    left_snowpark_name: str = make_unique_snowpark_name(left_column)
+    right_snowpark_name: str = make_unique_snowpark_name(right_column)

-                qualifiers.append(
-                    right_container.column_map.get_qualifier_for_spark_column(spark_col)
-                )
+    left_nullable, right_nullable = _join_with_nullability(info.join_type)

-    snowpark_cols_after_join_deduplicated = []
-    snowpark_cols_after_join_counter = Counter(snowpark_cols_after_join)
-    seen_duplicated_columns = set()
+    left_col, left_col_type = _construct_join_with_column(
+        left_container, left_snowpark_name, info.is_left_struct
+    )
+    right_col, right_col_type = _construct_join_with_column(
+        right_container, right_snowpark_name, info.is_right_struct
+    )

-    for col in snowpark_cols_after_join:
-        if snowpark_cols_after_join_counter[col] == 2:
-            # This means that the same column exists twice in the joined df, likely due to a self-join and
-            # we need to lsuffix and rsuffix to the names of both columns, similar to what Snowpark did under the hood.
+    result = joined_df.select(left_col, right_col)

-            suffix = (
-                DUPLICATED_JOIN_COL_RSUFFIX
-                if col in seen_duplicated_columns
-                else DUPLICATED_JOIN_COL_LSUFFIX
-            )
-            unquoted_col = unquote_if_quoted(col)
-            quoted = quote_name_without_upper_casing(unquoted_col + suffix)
-            snowpark_cols_after_join_deduplicated.append(quoted)
+    def _schema_getter() -> StructType:
+        return StructType(
+            [
+                StructField(left_snowpark_name, left_col_type, left_nullable),
+                StructField(right_snowpark_name, right_col_type, right_nullable),
+            ]
+        )
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[left_column, right_column],
+        snowpark_column_names=[left_snowpark_name, right_snowpark_name],
+        cached_schema_getter=_schema_getter,
+        column_metadata={},  # no top-level metadata for struct columns
+        # no qualifiers or equivalent snowpark names
+    )
+
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns

-            seen_duplicated_columns.add(col)
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if join_type == "cross" and has_join_condition:
+        # if the user provided any condition, it's no longer a cross join
+        join_type = "inner"
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    # Join types that only return columns from the left side:
+    # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
+    # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
+    # Both preserve only the columns from the left DataFrame without adding any columns from the right.
+    just_left_columns = join_type in ["leftanti", "leftsemi"]
+
+    # joinWith
+    is_join_with = rel.join.HasField("join_data_type")
+    is_left_struct = False
+    is_right_struct = False
+    if is_join_with:
+        is_left_struct = rel.join.join_data_type.is_left_struct
+        is_right_struct = rel.join.join_data_type.is_right_struct
+
+    return JoinInfo(
+        join_type,
+        condition_type,
+        join_columns,
+        just_left_columns,
+        is_join_with,
+        is_left_struct,
+        is_right_struct,
+    )
+
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer, right_plan: int
+) -> DataFrameContainer:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return right
+
+    # rename and create new right container
+    column_map = right.column_map
+    disambiguated_columns: list[Column] = []
+    disambiguated_snowpark_names: list[str] = []
+    # retain old snowpark names in column map
+    equivalent_snowpark_names: list[set[str]] = []
+    for c in column_map.columns:
+        col_equivalent_snowpark_names = copy(c.equivalent_snowpark_names)
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
         else:
-            snowpark_cols_after_join_deduplicated.append(col)
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+        equivalent_snowpark_names.append(col_equivalent_snowpark_names)
+
+    disambiguated_df = right.dataframe.select(*disambiguated_columns)
+
+    def _schema_getter() -> StructType:
+        fields = right.dataframe.schema.fields
+        return StructType(
+            [
+                StructField(name, fields[i].datatype, fields[i].nullable)
+                for i, name in enumerate(disambiguated_snowpark_names)
+            ]
+        )
+
+    disambiguated_right = DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=right.table_name,
+        cached_schema_getter=_schema_getter,
+        equivalent_snowpark_names=equivalent_snowpark_names,
+    )

-    column_metadata = {}
-    if left_container.column_map.column_metadata:
-        column_metadata.update(left_container.column_map.column_metadata)
+    # since we just renamed some snowpark columns, we need to update the dataframe container for the given plan_id
+    # TODO: is there a better way to do this?
+    if right_plan:
+        set_plan_id_map(right_plan, disambiguated_right)

+    return disambiguated_right
+
+
+def _combine_metadata(
+    left_container: DataFrameContainer, right_container: DataFrameContainer
+) -> dict:
+    column_metadata = dict(left_container.column_map.column_metadata or {})
     if right_container.column_map.column_metadata:
         for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
@@ -490,7 +641,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                     snowpark_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                         key
                     )
-                    expr_id = right_input[snowpark_name]._expression.expr_id
+                    expr_id = right_container.dataframe[
+                        snowpark_name
+                    ]._expression.expr_id
                     updated_key = COLUMN_METADATA_COLLISION_KEY.format(
                         expr_id=expr_id, key=snowpark_name
                     )
@@ -498,49 +651,137 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                 except Exception:
                     # ignore any errors that happens while fetching the metadata
                     pass
+    return column_metadata
+
+
+def _build_joined_schema(
+    snowpark_columns: list[str],
+    left_input: DataFrame,
+    right_input: DataFrame,
+    outer_join_columns: Optional[list[ColumnNames]] = None,
+) -> Callable[[], StructType]:
+    """
+    Builds a lazy schema for the joined dataframe, based on the given snowpark_columns and input dataframes.
+    In case of full outer joins, we need a separate target_snowpark_columns, since join columns will have different
+    names in the output than in any input.
+    """
+
+    def _schema_getter() -> StructType:
+        all_fields = left_input.schema.fields + right_input.schema.fields
+        fields: dict[str, StructField] = {f.name: f for f in all_fields}
+
+        if outer_join_columns:
+            visible_columns = [c for c in outer_join_columns if not c.is_hidden]
+            assert len(snowpark_columns) == len(visible_columns)
+
+            result_fields = []
+            visible_idx = 0
+            for col in outer_join_columns:
+                if col.is_hidden:
+                    source_field = fields[col.snowpark_name]
+                    result_fields.append(
+                        StructField(
+                            col.snowpark_name,
+                            source_field.datatype,
+                            source_field.nullable,
+                        )
+                    )
+                else:
+                    source_field = fields[snowpark_columns[visible_idx]]
+                    result_fields.append(
+                        StructField(
+                            col.snowpark_name,
+                            source_field.datatype,
+                            source_field.nullable,
+                        )
+                    )
+                    visible_idx += 1

-    result_container = DataFrameContainer.create_with_column_mapping(
-        dataframe=result,
-        spark_column_names=spark_cols_after_join,
-        snowpark_column_names=snowpark_cols_after_join_deduplicated,
-        column_metadata=column_metadata,
-        column_qualifiers=qualifiers,
-        hidden_columns=hidden_columns,
-        snowpark_column_types=snowpark_col_types,
-    )
-
-    if rel.join.using_columns:
-        # When join 'using_columns', the 'join columns' should go first in result DF.
-        idxs_to_shift = [
-            spark_cols_after_join.index(left_col_name)
-            for left_col_name in case_corrected_left_columns
-        ]
-
-        def reorder(lst: list) -> list:
-            to_move = [lst[i] for i in idxs_to_shift]
-            remaining = [el for i, el in enumerate(lst) if i not in idxs_to_shift]
-            return to_move + remaining
+            return StructType(result_fields)

-        # Create reordered DataFrame
-        reordered_df = result_container.dataframe.select(
-            [snowpark_fn.col(c) for c in reorder(result_container.dataframe.columns)]
+        return StructType(
+            [
+                StructField(name, fields[name].datatype, fields[name].nullable)
+                for name in snowpark_columns
+            ]
         )

-        # Create new container with reordered metadata
-        original_df = result_container.dataframe
-        return DataFrameContainer.create_with_column_mapping(
-            dataframe=reordered_df,
-            spark_column_names=reorder(result_container.column_map.get_spark_columns()),
-            snowpark_column_names=reorder(
-                result_container.column_map.get_snowpark_columns()
-            ),
-            column_metadata=column_metadata,
-            column_qualifiers=reorder(qualifiers),
-            table_name=result_container.table_name,
-            cached_schema_getter=lambda: snowpark.types.StructType(
-                reorder(original_df.schema.fields)
-            ),
-            hidden_columns=hidden_columns,
+    return _schema_getter
+
+
+def _make_struct_column(
+    container: DataFrameContainer, snowpark_name: str
+) -> tuple[snowpark.Column, StructType]:
+    column_metadata: dict = {}
+    for c in container.column_map.columns:
+        column_metadata[c.snowpark_name] = c
+
+    args: list[Column] = []
+    struct_fields: list[StructField] = []
+    for f in container.dataframe.schema.fields:
+        c = column_metadata[f.name]
+        if c.is_hidden:
+            continue
+        args.append(snowpark_fn.lit(c.spark_name))
+        args.append(snowpark_fn.col(c.snowpark_name))
+        struct_fields.append(
+            StructField(c.spark_name, f.datatype, f.nullable, _is_column=False)
         )

-    return result_container
+    struct_type = StructType(struct_fields, structured=True)
+    struct_col: snowpark.Column = (
+        snowpark_fn.object_construct_keep_null(*args)
+        .cast(struct_type)
+        .alias(snowpark_name)
+    )
+    return struct_col, struct_type
+
+
+def _construct_join_with_column(
+    container: DataFrameContainer, snowpark_name: str, is_struct: bool
+) -> tuple[Column, StructType]:
+    if is_struct:
+        return _make_struct_column(container, snowpark_name)
+    else:
+        # the dataframe must have a single field
+        cols = [
+            c.snowpark_name for c in container.column_map.columns if not c.is_hidden
+        ]
+        assert (
+            len(cols) == 1
+        ), "A non-struct dataframe must have a single column in joinWith"
+        field = None
+        for f in container.dataframe.schema.fields:
+            if f.name == cols[0]:
+                field = f
+                break
+        assert field is not None
+        col = snowpark_fn.col(field.name).alias(snowpark_name)
+        col_type = field.datatype
+        return col, col_type
+
+
+def _join_with_nullability(join_type: str) -> tuple[bool, bool]:
+    """
+    Returns the nullability for the left and right result columns of a joinWith operation.
+
+    The tuple corresponds to (left_nullable, right_nullable) and depends on the join type:
+    - "inner" or "cross": both columns are non-nullable
+    - "left": left is non-nullable, right is nullable
+    - "right": left is nullable, right is non-nullable
+    - "full_outer": both columns are nullable
+
+    Raises:
+        IllegalArgumentException: If the provided join type is unsupported.
+    """
+    match join_type:
+        case "inner" | "cross":
+            return False, False
+        case "left":
+            return False, True
+        case "right":
+            return True, False
+        case "full_outer":
+            return True, True
+        case _:
+            raise IllegalArgumentException(f"Unsupported join type '{join_type}'.")
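
For orientation only, and not part of the package diff: the docstring of the new _join_with_nullability helper above fully specifies how join type maps to result-column nullability for joinWith-style joins. A minimal, hypothetical Python sketch of that same mapping (the names below are illustrative, not snowpark-connect APIs):

# Hypothetical illustration of the nullability table documented in
# _join_with_nullability; these names are not part of snowpark-connect.
JOIN_WITH_NULLABILITY = {
    "inner": (False, False),      # neither result column can be NULL
    "cross": (False, False),
    "left": (False, True),        # right result column may be NULL
    "right": (True, False),       # left result column may be NULL
    "full_outer": (True, True),   # both result columns may be NULL
}

def join_with_nullability(join_type: str) -> tuple[bool, bool]:
    """Return (left_nullable, right_nullable) for a joinWith result."""
    try:
        return JOIN_WITH_NULLABILITY[join_type]
    except KeyError:
        raise ValueError(f"Unsupported join type '{join_type}'.") from None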