snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +717 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +309 -26
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +224 -15
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +86 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
- snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +171 -48
- snowflake/snowpark_connect/server.py +528 -473
- snowflake/snowpark_connect/server_common/__init__.py +503 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +195 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +192 -40
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_join.py

@@ -1,35 +1,50 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
+import dataclasses
+from collections.abc import Callable
+from copy import copy
+from enum import Enum
 from functools import reduce
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
-from pyspark.errors
+from pyspark.errors import AnalysisException
+from pyspark.errors.exceptions.connect import IllegalArgumentException
 
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
-from snowflake.snowpark
+from snowflake.snowpark import Column, DataFrame
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNames,
+    ColumnQualifier,
+    JoinColumnNameMap,
+    make_unique_snowpark_name,
 )
-from snowflake.snowpark_connect.column_name_handler import JoinColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.constants import COLUMN_METADATA_COLLISION_KEY
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
-from snowflake.snowpark_connect.error.
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import (
+    SparkException,
+    attach_custom_error_code,
+)
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
 from snowflake.snowpark_connect.expression.typer import JoinExpressionTyper
-from snowflake.snowpark_connect.hidden_column import HiddenColumn
 from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    without_internal_columns,
+)
 from snowflake.snowpark_connect.utils.context import (
    push_evaluating_join_condition,
    push_sql_scope,
+    set_plan_id_map,
    set_sql_plan_name,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -38,447 +53,583 @@ from snowflake.snowpark_connect.utils.telemetry import (
 
 USING_COLUMN_NOT_FOUND_ERROR = "[UNRESOLVED_USING_COLUMN_FOR_JOIN] USING column `{0}` not found on the {1} side of the join. The {1}-side columns: {2}"
 
+
+class ConditionType(Enum):
+    USING_COLUMNS = 1
+    JOIN_CONDITION = 2
+    NO_CONDITION = 3
+
+
+@dataclasses.dataclass
+class JoinInfo:
+    join_type: str
+    condition_type: ConditionType
+    join_columns: Optional[list[str]]
+    just_left_columns: bool
+    is_join_with: bool
+    is_left_struct: bool
+    is_right_struct: bool
+
+    def is_using_columns(self):
+        return self.condition_type == ConditionType.USING_COLUMNS
 
 
 def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
 
-    if is_natural_join:
-        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
-        left_spark_columns = left_container.column_map.get_spark_columns()
-        right_spark_columns = right_container.column_map.get_spark_columns()
-        common_spark_columns = [
-            x for x in left_spark_columns if x in right_spark_columns
-        ]
-        using_columns = common_spark_columns
+    # Remove any metadata columns(like metada$filename) present in the dataframes.
+    # We cannot support inputfilename for multisources as each dataframe has it's own source.
+    left_container = without_internal_columns(left_container)
+    right_container = without_internal_columns(right_container)
 
-        case other:
-            raise SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
-
-    # This handles case sensitivity for using_columns
-    case_corrected_right_columns: list[str] = []
-    hidden_columns = set()
-    # Propagate the hidden columns from left/right inputs to the result in case of chained joins
-    if left_container.column_map.hidden_columns:
-        hidden_columns.update(left_container.column_map.hidden_columns)
-
-    if right_container.column_map.hidden_columns:
-        hidden_columns.update(right_container.column_map.hidden_columns)
-
-    if rel.join.HasField("join_condition"):
-        assert not using_columns
-
-        left_columns = list(left_container.column_map.spark_to_col.keys())
-        right_columns = list(right_container.column_map.spark_to_col.keys())
-
-        # All PySpark join types are in the format of JOIN_TYPE_XXX.
-        # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
-        pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
-            10:
-        ].replace("_", " ")
-        with push_sql_scope(), push_evaluating_join_condition(
-            pyspark_join_type, left_columns, right_columns
-        ):
-            if left_container.alias is not None:
-                set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
-            if right_container.alias is not None:
-                set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
-            _, join_expression = map_single_column_expression(
-                rel.join.join_condition,
-                column_mapping=JoinColumnNameMap(
-                    left_container.column_map,
-                    right_container.column_map,
-                ),
-                typer=JoinExpressionTyper(left_input, right_input),
-            rsuffix=DUPLICATED_JOIN_COL_RSUFFIX,
-        )
-    elif using_columns:
-        if any(
-            left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c, allow_non_exists=True, return_first=True
-            import pyspark
-
-            raise pyspark.errors.AnalysisException(
-                USING_COLUMN_NOT_FOUND_ERROR.format(
-                    next(
-                        c
-                        for c in using_columns
-                        if left_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                            c, allow_non_exists=True, return_first=True
-                        )
-                        is None
-                    ),
-                    "left",
-                    left_container.column_map.get_spark_columns(),
-                )
-            )
+    left_plan = rel.join.left.common.plan_id
+    right_plan = rel.join.right.common.plan_id
+
+    # if there are any conflicting snowpark columns, this is the time to rename them
+    disambiguated_right_container = _disambiguate_snowpark_columns(
+        left_container, right_container, right_plan if left_plan != right_plan else None
+    )
+
+    join_info = _get_join_info(rel, left_container, disambiguated_right_container)
+
+    match join_info.condition_type:
+        case ConditionType.JOIN_CONDITION:
+            result_container = _join_using_condition(
+                left_container,
+                disambiguated_right_container,
+                join_info,
+                rel,
+                right_container if left_plan == right_plan else None,
+            )
+        case ConditionType.USING_COLUMNS:
+            result_container = _join_using_columns(
+                left_container,
+                disambiguated_right_container,
+                join_info,
+            )
+        case _:
+            result_container = _join_unconditionally(
+                left_container, disambiguated_right_container, join_info
+            )
+
+    return result_container
+
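For orientation, the three dispatch branches above correspond to the three ways a join is expressed through the PySpark DataFrame API. A minimal sketch, assuming a plain local SparkSession; the data and column names are illustrative and not taken from this package:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
left = spark.createDataFrame([(1, "a")], ["id", "v1"])
right = spark.createDataFrame([(1, "b")], ["id", "v2"])

# using_columns in the proto  -> ConditionType.USING_COLUMNS  -> _join_using_columns
left.join(right, on="id", how="inner")

# join_condition in the proto -> ConditionType.JOIN_CONDITION -> _join_using_condition
left.join(right, left["id"] == right["id"], how="left")

# no condition at all         -> ConditionType.NO_CONDITION   -> _join_unconditionally
left.crossJoin(right)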
+def _join_unconditionally(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    if info.join_type != "cross" and not global_config.spark_sql_crossJoin_enabled:
+        exception = SparkException.implicit_cartesian_product("inner")
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+    join_type = info.join_type
+
+    # For outer joins without a condition, we need to use a TRUE condition
+    # to match Spark's behavior.
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=snowpark_fn.lit(True)
+        if join_type in ["left", "right", "full_outer"]
+        else None,
+        how=join_type,
+    )
+
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)
+
+    if info.just_left_columns:
+        columns = left_container.column_map.columns
+        column_metadata = left_container.column_map.column_metadata
+        result = result.select(*left_container.column_map.get_snowpark_columns())
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )
+
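The guard at the top of _join_unconditionally mirrors Spark's protection against implicit cartesian products: a non-cross join with no condition is rejected unless the cross-join flag is enabled. A hedged sketch of the client-side setting that feeds global_config.spark_sql_crossJoin_enabled; the exact config plumbing is an assumption, only the flag name appears in this diff:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# With the flag disabled, an unconditioned non-cross join is expected to raise the
# implicit-cartesian-product error surfaced via SparkException above.
spark.conf.set("spark.sql.crossJoin.enabled", "true")
a = spark.range(3)
b = spark.range(3)
a.join(b).count()  # now allowed: evaluated as an unconditioned inner join (9 rows)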
+def _join_using_columns(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    join_columns = info.join_columns
+
+    def _validate_using_column(
+        column: str, container: DataFrameContainer, side: str
+    ) -> None:
+        if (
+            container.column_map.get_snowpark_column_name_from_spark_column_name(
+                column, allow_non_exists=True, return_first=True
+            )
+            is None
+        ):
+            exception = AnalysisException(
+                USING_COLUMN_NOT_FOUND_ERROR.format(
+                    column, side, container.column_map.get_spark_columns()
+                )
+            )
+            attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+            raise exception
+
+    for col in join_columns:
+        _validate_using_column(col, left_container, "left")
+        _validate_using_column(col, right_container, "right")
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # The inputs will have different snowpark names for the same spark name,
+    # so we convert ["a", "b"] into (left["a"] == right["a"] & left["b"] == right["b"]),
+    # then drop right["a"] and right["b"].
+    snowpark_using_columns = [
+        (
+            snowpark_fn.col(
+                left_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
+            snowpark_fn.col(
+                right_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                    spark_name, return_first=True
+                )
+            ),
+        )
+        for spark_name in join_columns
+    ]
+
+    # this is a condition join, so it will contain left + right columns
+    # we need to postprocess this later to have a correct projection
+    joined_df = left_input.join(
+        right=right_input,
+        on=reduce(
+            snowpark.Column.__and__,
+            (left == right for left, right in snowpark_using_columns),
+        ),
+        how=info.join_type,
+    )
+
+    # figure out default column ordering after the join
+    columns = left_container.column_map.get_columns_after_join(
+        right_container.column_map, join_columns, info.join_type
+    )
+
+    if info.join_type in ["full_outer", "left", "right"]:
+        all_columns_for_select = []
+        all_column_names = []
+
+        for column_info in columns[: len(join_columns)]:
+            spark_name = column_info.spark_name
+            left_sp_name = left_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                spark_name, return_first=True
+            )
+            right_sp_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
+                spark_name, return_first=True
+            )
+
+            if info.join_type == "full_outer":
+                new_sp_name = make_unique_snowpark_name(spark_name)
+                all_columns_for_select.append(
+                    snowpark_fn.coalesce(
+                        snowpark_fn.col(left_sp_name), snowpark_fn.col(right_sp_name)
+                    ).alias(new_sp_name)
+                )
+                all_column_names.append(
+                    ColumnNames(
+                        spark_name,
+                        new_sp_name,
+                        set(),
+                        equivalent_snowpark_names=set(),
+                        is_hidden=False,
+                    )
+                )
+
+                for sp_name, container in [
+                    (left_sp_name, left_container),
+                    (right_sp_name, right_container),
+                ]:
+                    all_columns_for_select.append(snowpark_fn.col(sp_name))
+                    all_column_names.append(
+                        ColumnNames(
+                            spark_name,
+                            sp_name,
+                            container.column_map.get_qualifiers_for_snowpark_column(
+                                sp_name
+                            ),
+                            equivalent_snowpark_names=container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                                sp_name
+                            ),
+                            is_hidden=True,
+                        )
+                    )
+            else:
+                for sp_name, container, side in [
+                    (left_sp_name, left_container, "left"),
+                    (right_sp_name, right_container, "right"),
+                ]:
+                    all_columns_for_select.append(snowpark_fn.col(sp_name))
+                    qualifiers = (
+                        container.column_map.get_qualifiers_for_snowpark_column(sp_name)
+                    )
+                    equivalent_snowpark_names = set()
+                    equivalent_snowpark_names.update(
+                        container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                            sp_name
+                        )
+                    )
+                    is_visible = info.join_type == side
+                    if is_visible:
+                        qualifiers = qualifiers | {ColumnQualifier(())}
+                    all_column_names.append(
+                        ColumnNames(
+                            spark_name,
+                            sp_name,
+                            qualifiers,
+                            equivalent_snowpark_names=equivalent_snowpark_names,
+                            is_hidden=not is_visible,
+                        )
+                    )
+
+        for c in columns[len(join_columns) :]:
+            all_columns_for_select.append(snowpark_fn.col(c.snowpark_name))
+            all_column_names.append(c)
+
+        result = joined_df.select(all_columns_for_select)
+        snowpark_names_for_schema = [c.snowpark_name for c in columns]
+
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[c.spark_name for c in all_column_names],
+            snowpark_column_names=[c.snowpark_name for c in all_column_names],
+            column_metadata=_combine_metadata(left_container, right_container),
+            column_qualifiers=[c.qualifiers for c in all_column_names],
+            column_is_hidden=[c.is_hidden for c in all_column_names],
+            cached_schema_getter=_build_joined_schema(
+                snowpark_names_for_schema,
+                left_input,
+                right_input,
+                all_column_names,
+            ),
+            equivalent_snowpark_names=[
+                c.equivalent_snowpark_names for c in all_column_names
+            ],
+        )
+
+    if info.just_left_columns:
+        # we just need the left columns
+        columns = columns[: len(left_container.column_map.columns)]
+        snowpark_columns = [c.snowpark_name for c in columns]
+        result = joined_df.select(*snowpark_columns)
+
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=[c.spark_name for c in columns],
+            snowpark_column_names=snowpark_columns,
+            column_metadata=left_container.column_map.column_metadata,
+            column_qualifiers=[c.qualifiers for c in columns],
+            cached_schema_getter=_build_joined_schema(
+                snowpark_columns, left_input, right_input
+            ),
+            equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+        )
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+    result = joined_df.select(*snowpark_columns)
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=_combine_metadata(left_container, right_container),
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )
+
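The full_outer branch above reproduces Spark's USING semantics: the join column appears once in the output, coalesced from both sides, while the original left and right key columns are kept only as hidden columns. As a point of reference, assuming a plain SparkSession (names illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
left = spark.createDataFrame([(1, "l1"), (2, "l2")], ["id", "lv"])
right = spark.createDataFrame([(2, "r2"), (3, "r3")], ["id", "rv"])

# One "id" column in the result, filled as coalesce(left.id, right.id): 1, 2, 3.
# This is the behavior the snowpark_fn.coalesce(...).alias(new_sp_name) projection emulates.
left.join(right, on="id", how="full_outer").show()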
+def _join_using_condition(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    info: JoinInfo,
+    rel: relation_proto.Relation,
+    original_right_container: Optional[DataFrameContainer],
+) -> DataFrameContainer:
+    left_columns = left_container.column_map.get_spark_columns()
+    right_columns = right_container.column_map.get_spark_columns()
+
+    left_input = left_container.dataframe
+    right_input = right_container.dataframe
+
+    # All PySpark join types are in the format of JOIN_TYPE_XXX.
+    # We remove the first 10 characters (JOIN_TYPE_) and replace all underscores with spaces to match the exception.
+    pyspark_join_type = relation_proto.Join.JoinType.Name(rel.join.join_type)[
+        10:
+    ].replace("_", " ")
+    with push_sql_scope(), push_evaluating_join_condition(
+        pyspark_join_type, left_columns, right_columns
+    ):
+        if left_container.alias is not None:
+            set_sql_plan_name(left_container.alias, rel.join.left.common.plan_id)
+        if right_container.alias is not None:
+            set_sql_plan_name(right_container.alias, rel.join.right.common.plan_id)
+        # resolve join condition expression
+        _, join_expression = map_single_column_expression(
+            rel.join.join_condition,
+            column_mapping=JoinColumnNameMap(
+                left_container.column_map,
+                # using the original (not disambiguated) right container is intended to break
+                # self join cases like a.join(a, a.id == a.id), since SAS can't handle them correctly
+                # and they fail in Spark Connect
+                (
+                    original_right_container
+                    if original_right_container
+                    else right_container
+                ).column_map,
+            ),
+            typer=JoinExpressionTyper(left_input, right_input),
+        )
+
+    result: snowpark.DataFrame = left_input.join(
+        right=right_input,
+        on=join_expression.col,
+        how=info.join_type,
+    )
+
+    # early return for joinWith
+    if info.is_join_with:
+        return _join_with(left_container, right_container, result, info)
+
+    # column order is already correct, so we just take the left + right side list
+    columns = left_container.column_map.columns + right_container.column_map.columns
+    column_metadata = _combine_metadata(left_container, right_container)
+
+    if info.just_left_columns:
+        # we just need left-side columns
+        columns = left_container.column_map.columns
+        result = result.select(*[c.snowpark_name for c in columns])
+        column_metadata = left_container.column_map.column_metadata
+
+    snowpark_columns = [c.snowpark_name for c in columns]
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[c.spark_name for c in columns],
+        snowpark_column_names=snowpark_columns,
+        column_metadata=column_metadata,
+        column_qualifiers=[c.qualifiers for c in columns],
+        cached_schema_getter=_build_joined_schema(
+            snowpark_columns, left_input, right_input
+        ),
+        equivalent_snowpark_names=[c.equivalent_snowpark_names for c in columns],
+    )
+
+def _join_with(
+    left_container: DataFrameContainer,
+    right_container: DataFrameContainer,
+    joined_df: DataFrame,
+    info: JoinInfo,
+) -> DataFrameContainer:
+    # joinWith always returns 2 columns
+    left_column = "_1"
+    right_column = "_2"
+    left_snowpark_name: str = make_unique_snowpark_name(left_column)
+    right_snowpark_name: str = make_unique_snowpark_name(right_column)
+
+    left_nullable, right_nullable = _join_with_nullability(info.join_type)
+
+    left_col, left_col_type = _construct_join_with_column(
+        left_container, left_snowpark_name, info.is_left_struct
+    )
+    right_col, right_col_type = _construct_join_with_column(
+        right_container, right_snowpark_name, info.is_right_struct
+    )
+
+    result = joined_df.select(left_col, right_col)
+
+    def _schema_getter() -> StructType:
+        return StructType(
+            [
+                StructField(left_snowpark_name, left_col_type, left_nullable),
+                StructField(right_snowpark_name, right_col_type, right_nullable),
+            ]
+        )
+
+    return DataFrameContainer.create_with_column_mapping(
+        dataframe=result,
+        spark_column_names=[left_column, right_column],
+        snowpark_column_names=[left_snowpark_name, right_snowpark_name],
+        cached_schema_getter=_schema_getter,
+        column_metadata={},  # no top-level metadata for struct columns
+        # no qualifiers or equivalent snowpark names
+    )
+
+def _get_join_info(
+    rel: relation_proto.Relation, left: DataFrameContainer, right: DataFrameContainer
+) -> JoinInfo:
+    """
+    Gathers basic information about the join, and performs basic assertions
+    """
+
+    is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
+    join_columns = rel.join.using_columns
+    if is_natural_join:
+        rel.join.join_type -= NATURAL_JOIN_TYPE_BASE
+        left_spark_columns = left.column_map.get_spark_columns()
+        right_spark_columns = right.column_map.get_spark_columns()
+        common_spark_columns = [
+            x for x in left_spark_columns if x in right_spark_columns
+        ]
+        join_columns = common_spark_columns
+
+    match rel.join.join_type:
+        case relation_proto.Join.JOIN_TYPE_UNSPECIFIED:
+            # TODO: Understand what UNSPECIFIED Join type is
+            exception = SnowparkConnectNotImplementedError("Unspecified Join Type")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        case relation_proto.Join.JOIN_TYPE_INNER:
+            join_type = "inner"
+        case relation_proto.Join.JOIN_TYPE_FULL_OUTER:
+            join_type = "full_outer"
+        case relation_proto.Join.JOIN_TYPE_LEFT_OUTER:
+            join_type = "left"
+        case relation_proto.Join.JOIN_TYPE_RIGHT_OUTER:
+            join_type = "right"
+        case relation_proto.Join.JOIN_TYPE_LEFT_ANTI:
+            join_type = "leftanti"
+        case relation_proto.Join.JOIN_TYPE_LEFT_SEMI:
+            join_type = "leftsemi"
+        case relation_proto.Join.JOIN_TYPE_CROSS:
+            join_type = "cross"
+        case other:
+            exception = SnowparkConnectNotImplementedError(f"Other Join Type: {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+
+    has_join_condition = rel.join.HasField("join_condition")
+    is_using_columns = bool(join_columns)
+
+    if join_type == "cross" and has_join_condition:
+        # if the user provided any condition, it's no longer a cross join
+        join_type = "inner"
+
+    if has_join_condition:
+        assert not is_using_columns
+
+    condition_type = ConditionType.NO_CONDITION
+    if has_join_condition:
+        condition_type = ConditionType.JOIN_CONDITION
+    elif is_using_columns:
+        condition_type = ConditionType.USING_COLUMNS
+
+    # Join types that only return columns from the left side:
+    # - LEFT SEMI JOIN: Returns left rows that have matches in right table (no right columns)
+    # - LEFT ANTI JOIN: Returns left rows that have NO matches in right table (no right columns)
+    # Both preserve only the columns from the left DataFrame without adding any columns from the right.
+    just_left_columns = join_type in ["leftanti", "leftsemi"]
+
+    # joinWith
+    is_join_with = rel.join.HasField("join_data_type")
+    is_left_struct = False
+    is_right_struct = False
+    if is_join_with:
+        is_left_struct = rel.join.join_data_type.is_left_struct
+        is_right_struct = rel.join.join_data_type.is_right_struct
+
+    return JoinInfo(
+        join_type,
+        condition_type,
+        join_columns,
+        just_left_columns,
+        is_join_with,
+        is_left_struct,
+        is_right_struct,
+    )
+
+def _disambiguate_snowpark_columns(
+    left: DataFrameContainer, right: DataFrameContainer, right_plan: int
+) -> DataFrameContainer:
+    conflicting_snowpark_columns = left.column_map.get_conflicting_snowpark_columns(
+        right.column_map
+    )
+
+    if not conflicting_snowpark_columns:
+        return right
+
+    # rename and create new right container
+    column_map = right.column_map
+    disambiguated_columns: list[Column] = []
+    disambiguated_snowpark_names: list[str] = []
+    # retain old snowpark names in column map
+    equivalent_snowpark_names: list[set[str]] = []
+    for c in column_map.columns:
+        col_equivalent_snowpark_names = copy(c.equivalent_snowpark_names)
+        if c.snowpark_name in conflicting_snowpark_columns:
+            # alias snowpark column with a new unique name
+            new_name = make_unique_snowpark_name(c.spark_name)
+            disambiguated_snowpark_names.append(new_name)
+            disambiguated_columns.append(
+                snowpark_fn.col(c.snowpark_name).alias(new_name)
+            )
+        else:
+            disambiguated_snowpark_names.append(c.snowpark_name)
+            disambiguated_columns.append(snowpark_fn.col(c.snowpark_name))
+
+        equivalent_snowpark_names.append(col_equivalent_snowpark_names)
+
+    disambiguated_df = right.dataframe.select(*disambiguated_columns)
+
+    def _schema_getter() -> StructType:
+        fields = right.dataframe.schema.fields
+        return StructType(
+            [
+                StructField(name, fields[i].datatype, fields[i].nullable)
+                for i, name in enumerate(disambiguated_snowpark_names)
+            ]
+        )
+
+    disambiguated_right = DataFrameContainer.create_with_column_mapping(
+        dataframe=disambiguated_df,
+        spark_column_names=column_map.get_spark_columns(),
+        snowpark_column_names=disambiguated_snowpark_names,
+        column_metadata=column_map.column_metadata,
+        column_qualifiers=column_map.get_qualifiers(),
+        table_name=right.table_name,
+        cached_schema_getter=_schema_getter,
+        equivalent_snowpark_names=equivalent_snowpark_names,
+    )
+
+    # since we just renamed some snowpark columns, we need to update the dataframe container for the given plan_id
+    # TODO: is there a better way to do this?
+    if right_plan:
+        set_plan_id_map(right_plan, disambiguated_right)
+
+    return disambiguated_right
+
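_disambiguate_snowpark_columns exists mainly for joins whose two inputs resolve to identical underlying Snowpark column names (most commonly self-joins), where the right side must be re-aliased to fresh unique names before the join is emitted. A small PySpark-level illustration of the situation it handles, assuming a plain SparkSession (names illustrative):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
emp = spark.createDataFrame([(1, None), (2, 1)], ["id", "manager_id"])

# Both join inputs come from the same plan, so every column name exists on both
# sides; without renaming one side, "id" and "manager_id" would be ambiguous
# after the join is translated to Snowpark.
emp.alias("e").join(emp.alias("m"), col("e.manager_id") == col("m.id"), "left")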
+def _combine_metadata(
+    left_container: DataFrameContainer, right_container: DataFrameContainer
+) -> dict:
+    column_metadata = dict(left_container.column_map.column_metadata or {})
     if right_container.column_map.column_metadata:
         for key, value in right_container.column_map.column_metadata.items():
             if key not in column_metadata:
@@ -490,7 +641,9 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
                snowpark_name = right_container.column_map.get_snowpark_column_name_from_spark_column_name(
                    key
                )
-                expr_id =
+                expr_id = right_container.dataframe[
+                    snowpark_name
+                ]._expression.expr_id
                updated_key = COLUMN_METADATA_COLLISION_KEY.format(
                    expr_id=expr_id, key=snowpark_name
                )
@@ -498,49 +651,137 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
            except Exception:
                # ignore any errors that happens while fetching the metadata
                pass
+    return column_metadata
+
+
+def _build_joined_schema(
+    snowpark_columns: list[str],
+    left_input: DataFrame,
+    right_input: DataFrame,
+    outer_join_columns: Optional[list[ColumnNames]] = None,
+) -> Callable[[], StructType]:
+    """
+    Builds a lazy schema for the joined dataframe, based on the given snowpark_columns and input dataframes.
+    In case of full outer joins, we need a separate target_snowpark_columns, since join columns will have different
+    names in the output than in any input.
+    """
+
+    def _schema_getter() -> StructType:
+        all_fields = left_input.schema.fields + right_input.schema.fields
+        fields: dict[str, StructField] = {f.name: f for f in all_fields}
+
+        if outer_join_columns:
+            visible_columns = [c for c in outer_join_columns if not c.is_hidden]
+            assert len(snowpark_columns) == len(visible_columns)
+
+            result_fields = []
+            visible_idx = 0
+            for col in outer_join_columns:
+                if col.is_hidden:
+                    source_field = fields[col.snowpark_name]
+                    result_fields.append(
+                        StructField(
+                            col.snowpark_name,
+                            source_field.datatype,
+                            source_field.nullable,
+                        )
+                    )
+                else:
+                    source_field = fields[snowpark_columns[visible_idx]]
+                    result_fields.append(
+                        StructField(
+                            col.snowpark_name,
+                            source_field.datatype,
+                            source_field.nullable,
+                        )
+                    )
+                    visible_idx += 1
+
+            return StructType(result_fields)
+
+        return StructType(
+            [
+                StructField(name, fields[name].datatype, fields[name].nullable)
+                for name in snowpark_columns
+            ]
+        )
+
+    return _schema_getter
+
+
+def _make_struct_column(
+    container: DataFrameContainer, snowpark_name: str
+) -> tuple[snowpark.Column, StructType]:
+    column_metadata: dict = {}
+    for c in container.column_map.columns:
+        column_metadata[c.snowpark_name] = c
+
+    args: list[Column] = []
+    struct_fields: list[StructField] = []
+    for f in container.dataframe.schema.fields:
+        c = column_metadata[f.name]
+        if c.is_hidden:
+            continue
+        args.append(snowpark_fn.lit(c.spark_name))
+        args.append(snowpark_fn.col(c.snowpark_name))
+        struct_fields.append(
+            StructField(c.spark_name, f.datatype, f.nullable, _is_column=False)
+        )
+
+    struct_type = StructType(struct_fields, structured=True)
+    struct_col: snowpark.Column = (
+        snowpark_fn.object_construct_keep_null(*args)
+        .cast(struct_type)
+        .alias(snowpark_name)
+    )
+    return struct_col, struct_type
+
+
+def _construct_join_with_column(
+    container: DataFrameContainer, snowpark_name: str, is_struct: bool
+) -> tuple[Column, StructType]:
+    if is_struct:
+        return _make_struct_column(container, snowpark_name)
+    else:
+        # the dataframe must have a single field
+        cols = [
+            c.snowpark_name for c in container.column_map.columns if not c.is_hidden
+        ]
+        assert (
+            len(cols) == 1
+        ), "A non-struct dataframe must have a single column in joinWith"
+        field = None
+        for f in container.dataframe.schema.fields:
+            if f.name == cols[0]:
+                field = f
+                break
+        assert field is not None
+        col = snowpark_fn.col(field.name).alias(snowpark_name)
+        col_type = field.datatype
+        return col, col_type
+
+
+def _join_with_nullability(join_type: str) -> tuple[bool, bool]:
+    """
+    Returns the nullability for the left and right result columns of a joinWith operation.
+
+    The tuple corresponds to (left_nullable, right_nullable) and depends on the join type:
+    - "inner" or "cross": both columns are non-nullable
+    - "left": left is non-nullable, right is nullable
+    - "right": left is nullable, right is non-nullable
+    - "full_outer": both columns are nullable
+
+    Raises:
+        IllegalArgumentException: If the provided join type is unsupported.
+    """
+    match join_type:
+        case "inner" | "cross":
+            return False, False
+        case "left":
+            return False, True
+        case "right":
+            return True, False
+        case "full_outer":
+            return True, True
+        case _:
+            raise IllegalArgumentException(f"Unsupported join type '{join_type}'.")