snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff shows the changes between these publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
```diff
--- a/snowflake/snowpark_connect/relation/map_extension.py
+++ b/snowflake/snowpark_connect/relation/map_extension.py
@@ -2,6 +2,9 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
+import copy
+from typing import Any
+
 import cloudpickle as pkl
 import pyspark.sql.connect.proto.expressions_pb2 as expression_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
@@ -11,22 +14,40 @@ import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark.types as snowpark_types
 import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
 from snowflake import snowpark
+from snowflake.snowpark import Column
 from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     make_column_names_snowpark_compatible,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import get_boolean_session_config_param
-from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
-from snowflake.snowpark_connect.expression.map_expression import map_expression
+from snowflake.snowpark_connect.dataframe_container import (
+    AggregateMetadata,
+    DataFrameContainer,
+)
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
+from snowflake.snowpark_connect.expression.map_expression import (
+    map_expression,
+    map_single_column_expression,
+)
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.utils import (
+    create_pivot_column_condition,
+    get_all_dependent_column_names,
+    map_pivot_value_to_spark_column_name,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
-    not_resolving_fun_args,
     push_outer_dataframe,
     set_current_grouping_columns,
 )
+from snowflake.snowpark_connect.utils.expression_transformer import (
+    inject_condition_to_all_agg_functions,
+)
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
@@ -85,16 +106,19 @@ def map_extension(
             input_df = result.dataframe
             snowpark_col_names = result.column_map.get_snowpark_columns()
             if len(subquery_aliases.aliases) != len(snowpark_col_names):
-                raise AnalysisException(
+                exception = AnalysisException(
                     "Number of column aliases does not match number of columns. "
                     f"Number of column aliases: {len(subquery_aliases.aliases)}; "
                     f"number of columns: {len(snowpark_col_names)}."
                 )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+                raise exception
             return DataFrameContainer.create_with_column_mapping(
                 dataframe=input_df,
                 spark_column_names=subquery_aliases.aliases,
                 snowpark_column_names=snowpark_col_names,
                 column_qualifiers=result.column_map.get_qualifiers(),
+                equivalent_snowpark_names=result.column_map.get_equivalent_snowpark_names(),
             )
         case "lateral_join":
             lateral_join = extension.lateral_join
@@ -109,18 +133,22 @@ def map_extension(
 
             left_queries = left_df.queries["queries"]
             if len(left_queries) != 1:
-                raise SnowparkConnectNotImplementedError(
+                exception = SnowparkConnectNotImplementedError(
                     f"Unexpected number of queries: {len(left_queries)}"
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                raise exception
             left_query = left_queries[0]
             with push_outer_dataframe(left_result):
                 right_result = map_relation(lateral_join.right)
             right_df = right_result.dataframe
             right_queries = right_df.queries["queries"]
             if len(right_queries) != 1:
-                raise SnowparkConnectNotImplementedError(
+                exception = SnowparkConnectNotImplementedError(
                     f"Unexpected number of queries: {len(right_queries)}"
                 )
+                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+                raise exception
             right_query = right_queries[0]
             input_df_sql = f"WITH __left AS ({left_query}) SELECT * FROM __left INNER JOIN LATERAL ({right_query})"
             session = snowpark.Session.get_active_session()
@@ -133,6 +161,8 @@ def map_extension(
                 + right_result.column_map.get_snowpark_columns(),
                 column_qualifiers=left_result.column_map.get_qualifiers()
                 + right_result.column_map.get_qualifiers(),
+                equivalent_snowpark_names=left_result.column_map.get_equivalent_snowpark_names()
+                + right_result.column_map.get_equivalent_snowpark_names(),
             )
 
         case "udtf_with_table_arguments":
@@ -140,7 +170,11 @@ def map_extension(
         case "aggregate":
             return map_aggregate(extension.aggregate, rel.common.plan_id)
         case other:
-            raise SnowparkConnectNotImplementedError(f"Unexpected extension {other}")
+            exception = SnowparkConnectNotImplementedError(
+                f"Unexpected extension {other}"
+            )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
 
 
 def get_udtf_project(relation: relation_proto.Relation) -> bool:
@@ -167,7 +201,7 @@ def get_udtf_project(relation: relation_proto.Relation) -> bool:
 
 def handle_udtf_with_table_arguments(
     udtf_info: snowflake_proto.UDTFWithTableArguments,
-) ->
+) -> DataFrameContainer:
     """
     Handle UDTF with one or more table arguments using Snowpark's join_table_function.
     For multiple table arguments, this creates a Cartesian product of all input tables.
@@ -175,7 +209,9 @@ def handle_udtf_with_table_arguments(
     session = snowpark.Session.get_active_session()
     udtf_name_lower = udtf_info.function_name.lower()
     if udtf_name_lower not in session._udtfs:
-        raise ValueError(f"UDTF '{udtf_info.function_name}' not found.")
+        exception = ValueError(f"UDTF '{udtf_info.function_name}' not found.")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
     _udtf_obj, udtf_spark_output_names = session._udtfs[udtf_name_lower]
 
     table_containers = []
@@ -189,10 +225,12 @@ def handle_udtf_with_table_arguments(
         if not get_boolean_session_config_param(
             "spark.sql.tvf.allowMultipleTableArguments.enabled"
         ):
-            raise AnalysisException(
+            exception = AnalysisException(
                 "[TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS] Multiple table arguments are not enabled. "
                 "Please set `spark.sql.tvf.allowMultipleTableArguments.enabled` to `true`"
             )
+            attach_custom_error_code(exception, ErrorCodes.CONFIG_NOT_ENABLED)
+            raise exception
 
     base_df = table_containers[0][0].dataframe
     first_table_col_count = len(base_df.columns)
@@ -271,7 +309,7 @@ def handle_lateral_join_with_udtf(
     left_result: DataFrameContainer,
     udtf_relation: relation_proto.Relation,
     udtf_info: tuple[snowpark.udtf.UserDefinedTableFunction, list],
-) ->
+) -> DataFrameContainer:
     """
     Handle lateral join with UDTF on the right side using join_table_function.
     """
@@ -298,13 +336,15 @@ def handle_lateral_join_with_udtf(
         + udtf_spark_output_names,
         snowpark_column_names=result_df.columns,
         column_qualifiers=left_result.column_map.get_qualifiers()
-        + [
+        + [set() for _ in udtf_spark_output_names],
+        equivalent_snowpark_names=left_result.column_map.get_equivalent_snowpark_names()
+        + [set() for _ in udtf_spark_output_names],
     )
 
 
 def map_aggregate(
     aggregate: snowflake_proto.Aggregate, plan_id: int
-) ->
+) -> DataFrameContainer:
     input_container = map_relation(aggregate.input)
     input_df: snowpark.DataFrame = input_container.dataframe
 
@@ -336,18 +376,19 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)
 
     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-        with not_resolving_fun_args():
-            new_names, snowpark_column = map_expression(
-                exp, input_container.column_map, typer
+        new_names, snowpark_column = map_expression(
+            exp, input_container.column_map, typer
+        )
+        if len(new_names) != 1:
+            exception = SnowparkConnectNotImplementedError(
+                "Multi-column aggregate expressions are not supported"
             )
-            if len(new_names) != 1:
-                raise SnowparkConnectNotImplementedError(
-                    "Multi-column aggregate expressions are not supported"
-                )
-            return new_names[0], snowpark_column
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
+        return new_names[0], snowpark_column
 
     raw_groupings: list[tuple[str, TypedColumn]] = []
-    raw_aggregations: list[tuple[str, TypedColumn]] = []
+    raw_aggregations: list[tuple[str, TypedColumn, set[ColumnQualifier]]] = []
 
     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -377,10 +418,22 @@ def map_aggregate(
     # Note: We don't clear the map here to preserve any parent context aliases
     from snowflake.snowpark_connect.utils.context import register_lca_alias
 
+    # If it's an unresolved attribute when its in aggregate.aggregate_expressions, we know it came from the parent map straight away
+    # in this case, we should see if the parent map has a qualifier for it and propagate that here, in case the order by references it in
+    # a qualified way later.
     agg_count = get_sql_aggregate_function_count()
     for exp in aggregate.aggregate_expressions:
         col = _map_column(exp)
-        raw_aggregations.append(col)
+        if exp.WhichOneof("expr_type") == "unresolved_attribute":
+            qualifiers: set[
+                ColumnQualifier
+            ] = input_container.column_map.get_qualifiers_for_snowpark_column(
+                col[1].col.get_name()
+            )
+        else:
+            qualifiers = set()
+
+        raw_aggregations.append((col[0], col[1], qualifiers))
 
         # If this is an alias, register it in the LCA map for subsequent expressions
         if (
@@ -411,18 +464,20 @@ def map_aggregate(
     spark_columns: list[str] = []
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
+    all_qualifiers: list[set[ColumnQualifier]] = []
 
     # Use grouping columns directly without aliases
-    groupings = [tc.col for _, tc in raw_groupings]
+    groupings: list[Column] = [tc.col for _, tc in raw_groupings]
 
     # Create aliases only for aggregation columns
     aggregations = []
-    for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+    for i, (spark_name, snowpark_column, qualifiers) in enumerate(raw_aggregations):
         alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]
 
         spark_columns.append(spark_name)
         snowpark_columns.append(alias)
         snowpark_column_types.append(snowpark_column.typ)
+        all_qualifiers.append(qualifiers)
 
         aggregations.append(snowpark_column.col.alias(alias))
 
@@ -431,12 +486,18 @@ def map_aggregate(
             if groupings:
                 # Normal GROUP BY with explicit grouping columns
                 result = input_df.group_by(groupings)
-            else:
+            elif not is_group_by_all:
                 # No explicit GROUP BY - this is an aggregate over the entire table
                 # Use a dummy constant that will be excluded from the final result
                 result = input_df.with_column(
                     "__dummy_group__", snowpark_fn.lit(1)
                 ).group_by("__dummy_group__")
+            else:
+                # GROUP BY ALL with only one aggregate column
+                # Snowpark doesn't support GROUP BY ALL
+                # TODO: Change in future with Snowpark Supported arguments or API for GROUP BY ALL
+                result = input_df.group_by()
+
         case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
             result = input_df.rollup(groupings)
         case snowflake_proto.Aggregate.GROUP_TYPE_CUBE:
@@ -456,10 +517,148 @@ def map_aggregate(
             result = input_df.group_by_grouping_sets(
                 snowpark.GroupingSets(*sets_mapped)
            )
+        case snowflake_proto.Aggregate.GROUP_TYPE_PIVOT:
+            pivot_typed_columns: list[TypedColumn] = [
+                map_single_column_expression(
+                    pivot_col,
+                    input_container.column_map,
+                    ExpressionTyper(input_df),
+                )[1]
+                for pivot_col in aggregate.pivot.pivot_columns
+            ]
+
+            pivot_columns = [col.col for col in pivot_typed_columns]
+            pivot_column_types = [col.typ for col in pivot_typed_columns]
+
+            pivot_values: list[list[Any]] = []
+            pivot_aliases: list[str] = []
+
+            for pivot_value in aggregate.pivot.pivot_values:
+                current_values = [
+                    get_literal_field_and_name(val)[0] for val in pivot_value.values
+                ]
+                pivot_values.append(current_values)
+
+                if pivot_value.alias:
+                    pivot_aliases.append(pivot_value.alias)
+
+            spark_col_names = []
+            final_pivot_names = []
+            grouping_columns_qualifiers = []
+            aggregations_pivot = []
+
+            pivot_col_names: set[str] = {col.get_name() for col in pivot_columns}
+
+            agg_columns = get_all_dependent_column_names(aggregations)
+
+            if groupings:
+                for col in groupings:
+                    snowpark_name = col.get_name()
+                    spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                        snowpark_name
+                    )
+                    qualifiers = (
+                        input_container.column_map.get_qualifiers_for_snowpark_column(
+                            snowpark_name
+                        )
+                    )
+                    grouping_columns_qualifiers.append(qualifiers)
+                    spark_col_names.append(spark_col_name)
+            else:
+                for col in input_container.column_map.columns:
+                    if (
+                        col.snowpark_name not in pivot_col_names
+                        and col.snowpark_name not in agg_columns
+                    ):
+                        groupings.append(snowpark_fn.col(col.snowpark_name))
+                        grouping_columns_qualifiers.append(col.qualifiers)
+                        spark_col_names.append(col.spark_name)
+
+            for pivot_value_idx, pivot_value_group in enumerate(pivot_values):
+                pivot_values_spark_names = []
+                pivot_value_is_null = []
+
+                for val in pivot_value_group:
+                    spark_name, is_null = map_pivot_value_to_spark_column_name(val)
+
+                    pivot_values_spark_names.append(spark_name)
+                    pivot_value_is_null.append(is_null)
+
+                for agg_idx, agg_expression in enumerate(aggregations):
+                    agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+                    condition = None
+                    for pivot_col_idx, (pivot_col, pivot_val) in enumerate(
+                        zip(pivot_columns, pivot_value_group)
+                    ):
+                        current_condition = create_pivot_column_condition(
+                            pivot_col,
+                            pivot_val,
+                            pivot_value_is_null[pivot_col_idx],
+                            pivot_column_types[pivot_col_idx]
+                            if isinstance(pivot_val, (list, dict))
+                            else None,
+                        )
+
+                        condition = (
+                            current_condition
+                            if condition is None
+                            else condition & current_condition
+                        )
+
+                    inject_condition_to_all_agg_functions(agg_fun_expr, condition)
+                    curr_expression = Column(agg_fun_expr)
+
+                    if pivot_aliases and not any(pivot_value_is_null):
+                        aliased_pivoted_column_spark_name = pivot_aliases[
+                            pivot_value_idx
+                        ]
+                    elif len(pivot_values_spark_names) > 1:
+                        aliased_pivoted_column_spark_name = (
+                            "{" + ", ".join(pivot_values_spark_names) + "}"
+                        )
+                    else:
+                        aliased_pivoted_column_spark_name = pivot_values_spark_names[0]
+
+                    spark_col_name = (
+                        f"{aliased_pivoted_column_spark_name}_{raw_aggregations[agg_idx][0]}"
+                        if len(aggregations) > 1
+                        else f"{aliased_pivoted_column_spark_name}"
+                    )
+
+                    snowpark_col_name = make_column_names_snowpark_compatible(
+                        [spark_col_name],
+                        plan_id,
+                        len(aggregations) + len(groupings),
+                    )[0]
+
+                    curr_expression = curr_expression.alias(snowpark_col_name)
+
+                    aggregations_pivot.append(curr_expression)
+                    spark_col_names.append(spark_col_name)
+                    final_pivot_names.append(snowpark_col_name)
+
+            result_df = input_df.group_by(*groupings).agg(*aggregations_pivot)
+
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=result_df,
+                spark_column_names=spark_col_names,
+                snowpark_column_names=result_df.columns,
+                snowpark_column_types=[
+                    result_df.schema.fields[idx].datatype
+                    for idx, _ in enumerate(result_df.columns)
+                ],
+                column_qualifiers=grouping_columns_qualifiers
+                + [set() for _ in final_pivot_names],
+                parent_column_name_map=input_container.column_map,
+            )
+
         case other:
-            raise SnowparkConnectNotImplementedError(
+            exception = SnowparkConnectNotImplementedError(
                 f"Unsupported GROUP BY type: {other}"
             )
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception
 
     result = result.agg(*aggregations, exclude_grouping_columns=True)
 
@@ -479,6 +678,13 @@ def map_aggregate(
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
+        column_qualifiers=all_qualifiers,
+        equivalent_snowpark_names=[
+            input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                new_name
+            )
+            for new_name in snowpark_columns
+        ],
     ).column_map
 
     # Create hybrid column map that can resolve both input and aggregate contexts
@@ -490,7 +696,9 @@ def map_aggregate(
         aggregate_expressions=list(aggregate.aggregate_expressions),
         grouping_expressions=list(aggregate.grouping_expressions),
         spark_columns=spark_columns,
-        raw_aggregations=raw_aggregations,
+        raw_aggregations=[
+            (spark_name, col) for spark_name, col, _ in raw_aggregations
+        ],
     )
 
     # Map the HAVING condition using hybrid resolution
@@ -504,11 +712,37 @@ def map_aggregate(
         # grouping sets don't allow ORDER BY with columns that aren't in the aggregate list.
         result = result.select(result.columns[-len(aggregations) :])
 
+    # Store aggregate metadata for ORDER BY resolution
+    # Only for regular GROUP BY - ROLLUP, CUBE, and GROUPING_SETS should NOT allow
+    # ORDER BY to reference pre-aggregation columns (Spark compatibility)
+    # This enables ORDER BY to resolve expressions that reference pre-aggregation columns
+    # (e.g., ORDER BY year(date) when only 'year' alias exists in aggregated result)
+    aggregate_metadata = None
+    if aggregate.group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
+        aggregate_metadata = AggregateMetadata(
+            input_column_map=input_container.column_map,
+            input_dataframe=input_df,
+            grouping_expressions=list(aggregate.grouping_expressions),
+            aggregate_expressions=list(aggregate.aggregate_expressions),
+            spark_columns=spark_columns,
+            raw_aggregations=[
+                (spark_name, col) for spark_name, col, _ in raw_aggregations
+            ],
+        )
+
     # Return only aggregation columns in the column map
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
        snowpark_column_types=snowpark_column_types,
-        parent_column_name_map=
+        parent_column_name_map=input_container.column_map,
+        column_qualifiers=all_qualifiers,
+        equivalent_snowpark_names=[
+            input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                new_name
+            )
+            for new_name in snowpark_columns
+        ],
+        aggregate_metadata=aggregate_metadata,
     )
```
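
A pattern that recurs throughout this diff: direct `raise` statements are replaced by exceptions that are tagged with a stable error code via `attach_custom_error_code` before being raised, with the codes coming from the new `error/error_codes.py` module. A minimal sketch of that shape, assuming simplified stand-ins for `ErrorCodes` and `attach_custom_error_code` (their real definitions live in the new error modules and are not shown in this diff):

```python
from enum import Enum


class ErrorCodes(Enum):
    # Assumption: a small enum of stable error identifiers; the real class in
    # snowflake/snowpark_connect/error/error_codes.py is not shown here.
    UNSUPPORTED_OPERATION = "UNSUPPORTED_OPERATION"
    INVALID_OPERATION = "INVALID_OPERATION"


def attach_custom_error_code(exception: Exception, code: ErrorCodes) -> None:
    # Assumption: the real helper annotates the exception object in place so a
    # structured code can be reported alongside the original message.
    exception.error_code = code  # type: ignore[attr-defined]


def fail_unsupported(feature: str) -> None:
    # The build-tag-raise sequence used throughout the diff.
    exception = NotImplementedError(f"Unsupported operation: {feature}")
    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
    raise exception
```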
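
The new `GROUP_TYPE_PIVOT` branch implements pivot as conditional aggregation rather than as a native PIVOT: for each combination of pivot value and aggregate expression, it injects an equality condition on the pivot column(s) into the aggregate function (`inject_condition_to_all_agg_functions`) and aliases the result per pivot value. A minimal PySpark sketch of the same idea, using `sum(when(...))` in place of the package's expression rewriting (the column and table names here are illustrative only):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [("2023", "Q1", 10), ("2023", "Q2", 20), ("2024", "Q1", 30)],
    ["year", "quarter", "amount"],
)

# One aggregate per pivot value: the WHEN condition plays the role of the
# injected pivot condition, and the alias plays the role of the generated
# per-value column name.
pivot_values = ["Q1", "Q2"]
aggregations = [
    F.sum(F.when(F.col("quarter") == value, F.col("amount"))).alias(value)
    for value in pivot_values
]

df.groupBy("year").agg(*aggregations).show()
```

Rows whose pivot column does not match a given value contribute NULL to that value's aggregate, which `sum` ignores, so the result matches the usual pivot semantics.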