snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
@@ -2,16 +2,21 @@
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #

+ import math
  import re
+ import typing
  from collections.abc import MutableMapping, MutableSequence
- from contextlib import contextmanager
+ from contextlib import contextmanager, suppress
  from contextvars import ContextVar
+ from decimal import Decimal
  from functools import reduce
+ from typing import Tuple

  import jpype
  import pandas
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+ import pyspark.sql.connect.proto.types_pb2 as types_proto
  import sqlglot
  from google.protobuf.any_pb2 import Any
  from pyspark.errors.exceptions.base import (
@@ -24,20 +29,28 @@ import snowflake.snowpark.functions as snowpark_fn
  import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_exp_proto
  import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
  from snowflake import snowpark
+ from snowflake.snowpark import Session, types as snowpark_types
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  quote_name_without_upper_casing,
  unquote_if_quoted,
  )
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
+ from snowflake.snowpark.functions import when_matched, when_not_matched
  from snowflake.snowpark_connect.config import (
+ auto_uppercase_column_identifiers,
  auto_uppercase_non_column_identifiers,
+ check_table_supports_operation,
  get_boolean_session_config_param,
  global_config,
+ record_table_metadata,
  set_config_param,
+ should_create_temporary_view_in_snowflake,
  unset_config_param,
  )
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.expression.map_expression import (
  ColumnNameMap,
  map_single_column_expression,
@@ -51,14 +64,28 @@ from snowflake.snowpark_connect.relation.map_relation import (
  NATURAL_JOIN_TYPE_BASE,
  map_relation,
  )
- from snowflake.snowpark_connect.type_mapping import map_snowpark_to_pyspark_types
+
+ # Import from utils for consistency
+ from snowflake.snowpark_connect.relation.utils import is_aggregate_function
+ from snowflake.snowpark_connect.snowflake_session import (
+ SQL_PASS_THROUGH_MARKER,
+ calculate_checksum,
+ )
+ from snowflake.snowpark_connect.type_mapping import (
+ map_snowpark_to_pyspark_types,
+ snowpark_to_proto_type,
+ )
  from snowflake.snowpark_connect.utils.context import (
  _accessing_temp_object,
  gen_sql_plan_id,
- get_session_id,
+ get_is_processing_aliased_relation,
+ get_spark_session_id,
  get_sql_plan,
  push_evaluating_sql_scope,
+ push_processed_view,
+ push_processing_aliased_relation_scope,
  push_sql_scope,
+ set_plan_id_map,
  set_sql_args,
  set_sql_plan_name,
  )
@@ -68,6 +95,7 @@ from snowflake.snowpark_connect.utils.telemetry import (
  telemetry,
  )

+ from .. import column_name_handler
  from ..expression.map_sql_expression import (
  _window_specs,
  as_java_list,
@@ -75,7 +103,18 @@ from ..expression.map_sql_expression import (
  map_logical_plan_expression,
  sql_parser,
  )
- from ..utils.identifiers import spark_to_sf_single_id
+ from ..typed_column import TypedColumn
+ from ..utils.identifiers import (
+ spark_to_sf_single_id,
+ spark_to_sf_single_id_with_unquoting,
+ )
+ from ..utils.temporary_view_helper import (
+ create_snowflake_temporary_view,
+ get_temp_view,
+ store_temporary_view_as_dataframe,
+ unregister_temp_view,
+ )
+ from .catalogs import SNOWFLAKE_CATALOG

  _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
  _cte_definitions = ContextVar[dict[str, any]]("_cte_definitions", default={})
@@ -84,6 +123,65 @@ _having_condition = ContextVar[expressions_proto.Expression | None](
  )


+ def _map_value_to_literal_proto(
+ value: typing.Any, typ: snowpark_types.DataType
+ ) -> expressions_proto.Expression.Literal:
+ if isinstance(typ, snowpark_types.NullType):
+ return expressions_proto.Expression.Literal(null=value)
+ if isinstance(typ, snowpark_types.BinaryType):
+ return expressions_proto.Expression.Literal(binary=value)
+ if isinstance(typ, snowpark_types.BooleanType):
+ return expressions_proto.Expression.Literal(boolean=value)
+ if isinstance(typ, snowpark_types.ByteType):
+ return expressions_proto.Expression.Literal(byte=value)
+ if isinstance(typ, snowpark_types.ShortType):
+ return expressions_proto.Expression.Literal(short=value)
+ if isinstance(typ, snowpark_types.IntegerType):
+ return expressions_proto.Expression.Literal(integer=value)
+ if isinstance(typ, snowpark_types.LongType):
+ return expressions_proto.Expression.Literal(long=value)
+ if isinstance(typ, snowpark_types.FloatType):
+ return expressions_proto.Expression.Literal(float=value)
+ if isinstance(typ, snowpark_types.DoubleType):
+ return expressions_proto.Expression.Literal(double=value)
+ if isinstance(typ, snowpark_types.DecimalType):
+ return expressions_proto.Expression.Literal(
+ decimal=expressions_proto.Expression.Literal.Decimal(
+ value=value,
+ precision=typ.precision,
+ scale=typ.scale,
+ )
+ )
+ if isinstance(typ, snowpark_types.ArrayType):
+ element_type_proto = types_proto.DataType(
+ **snowpark_to_proto_type(typ.element_type)
+ )
+
+ return expressions_proto.Expression.Literal(
+ array=expressions_proto.Expression.Literal.Array(
+ element_type=element_type_proto,
+ elements=[
+ _map_value_to_literal_proto(el, typ.element_type) for el in value
+ ],
+ )
+ )
+
+ if isinstance(typ, snowpark_types.StructType):
+ struct_type_proto = types_proto.DataType(**snowpark_to_proto_type(typ))
+
+ return expressions_proto.Expression.Literal(
+ struct=expressions_proto.Expression.Literal.Struct(
+ struct_type=struct_type_proto,
+ elements=[
+ _map_value_to_literal_proto(v, typ.fields[i].datatype)
+ for i, v in enumerate(value.values())
+ ],
+ )
+ )
+
+ return expressions_proto.Expression.Literal(string=str(value))
+
+
  def _is_sql_select_statement_helper(sql_string: str) -> bool:
  """
  Determine if a SQL string is a SELECT or CTE query statement, even when it starts with comments or whitespace.
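For orientation, the new _map_value_to_literal_proto helper above walks a Snowpark value/type pair and produces the matching Spark Connect Expression.Literal, recursing into arrays and structs and falling back to a string literal for unrecognized types. A minimal illustrative call, not part of the package code; names follow the imports shown in this diff:

    from snowflake.snowpark import types as snowpark_types

    # A plain integer maps to Literal(integer=42).
    _map_value_to_literal_proto(42, snowpark_types.IntegerType())

    # An array recurses element by element: [1, 2] becomes a Literal.Array whose
    # elements are integer literals and whose element_type comes from snowpark_to_proto_type.
    _map_value_to_literal_proto([1, 2], snowpark_types.ArrayType(snowpark_types.IntegerType()))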
@@ -130,6 +228,48 @@ def _push_cte_scope():
  _cte_definitions.reset(def_token)


+ def _process_cte_relations(cte_relations):
+ """
+ Process CTE relations and register them in the current CTE scope.
+
+ This function extracts CTE definitions from CTE relations,
+ maps them to protobuf representations, and stores them for later reference.
+
+ Args:
+ cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+ """
+ for cte in as_java_list(cte_relations):
+ name = str(cte._1())
+ # Store the original CTE definition for re-evaluation
+ _cte_definitions.get()[name] = cte._2()
+ # Process CTE definition with a unique plan_id to ensure proper column naming
+ # Clear HAVING condition before processing each CTE to prevent leakage between CTEs
+ saved_having = _having_condition.get()
+ _having_condition.set(None)
+ try:
+ cte_plan_id = gen_sql_plan_id()
+ cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
+ _ctes.get()[name] = cte_proto
+ finally:
+ _having_condition.set(saved_having)
+
+
+ @contextmanager
+ def _with_cte_scope(cte_relations):
+ """
+ Context manager that creates a CTE scope and processes CTE relations.
+
+ This combines _push_cte_scope() and _process_cte_relations() to handle
+ the common pattern of processing CTEs within a new scope.
+
+ Args:
+ cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+ """
+ with (_push_cte_scope()):
+ _process_cte_relations(cte_relations)
+ yield
+
+
  @contextmanager
  def _push_window_specs_scope():
  """
@@ -203,6 +343,9 @@ def _rename_columns(
  def _create_table_as_select(logical_plan, mode: str) -> None:
  # TODO: for as select create tables we'd map multi layer identifier here
  name = get_relation_identifier_name(logical_plan.name())
+ full_table_identifier = get_relation_identifier_name(
+ logical_plan.name(), is_multi_part=True
+ )
  comment = logical_plan.tableSpec().comment()

  container = execute_logical_plan(logical_plan.query())
@@ -223,9 +366,158 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
  mode=mode,
  )

+ # Record table metadata for CREATE TABLE AS SELECT
+ # These are typically considered v2 tables and support RENAME COLUMN
+ record_table_metadata(
+ table_identifier=full_table_identifier,
+ table_type="v2",
+ data_source="default",
+ supports_column_rename=True,
+ )
+
+
+ def _insert_into_table(logical_plan, session: Session) -> None:
+ df_container = execute_logical_plan(logical_plan.query())
+ df = df_container.dataframe
+ queries = df.queries["queries"]
+ if len(queries) != 1:
+ exception = SnowparkConnectNotImplementedError(
+ f"Unexpected number of queries: {len(queries)}"
+ )
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+ raise exception
+
+ name = get_relation_identifier_name(logical_plan.table(), True)
+
+ user_columns = [
+ spark_to_sf_single_id(str(col), is_column=True)
+ for col in as_java_list(logical_plan.userSpecifiedCols())
+ ]
+ overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
+ cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+
+ # Extract partition spec if any
+ partition_spec = logical_plan.partitionSpec()
+ partition_map = as_java_map(partition_spec)
+
+ partition_columns = {}
+ for entry in partition_map.entrySet():
+ col_name = str(entry.getKey())
+ value_option = entry.getValue()
+ if value_option.isDefined():
+ partition_columns[col_name] = value_option.get()
+
+ target_table = session.table(name)
+ target_schema = target_table.schema
+
+ # Add partition columns to the dataframe
+ if partition_columns:
+ """
+ Spark sends them in the partition spec and the values won't be present in the values array.
+ As snowflake does not support static partitions in INSERT INTO statements,
+ we need to add the partition columns to the dataframe as literal columns.
+
+ ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+ Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+ Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+ We need to add the partition columns to the dataframe as literal columns.
+
+ ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+ df = df.withColumn('hr', snowpark_fn.lit(10))
+
+ Then the final query will be:
+ INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+ """
+ for partition_col, partition_value in partition_columns.items():
+
+ def _comparable_col_name(col: str) -> str:
+ name = col.upper() if auto_uppercase_column_identifiers() else col
+ return unquote_if_quoted(name)
+
+ comparable_target_schema = [
+ _comparable_col_name(col.name) for col in target_schema.fields
+ ]
+
+ if _comparable_col_name(partition_col) not in comparable_target_schema:
+ exception = AnalysisException(
+ f"{partition_col} is not a valid partition column in table {name}."
+ )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+ raise exception
+ df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
+
+ expected_number_of_columns = (
+ len(user_columns) if user_columns else len(target_schema.fields)
+ )
+ if expected_number_of_columns != len(df.schema.fields):
+ reason = (
+ "too many data columns"
+ if len(df.schema.fields) > expected_number_of_columns
+ else "not enough data columns"
+ )
+ exception = AnalysisException(
+ f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+ f'Table columns: {", ".join(target_schema.names)}.\n'
+ f'Data columns: {", ".join(df.schema.names)}.'
+ )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+ raise exception
+
+ try:
+ # Modify df with type conversions and struct field name mapping
+ modified_columns = []
+ for source_field, target_field in zip(df.schema.fields, target_schema.fields):
+ col_name = source_field.name
+
+ # Handle different type conversions
+ if isinstance(
+ target_field.datatype, snowpark.types.DecimalType
+ ) and isinstance(
+ source_field.datatype,
+ (snowpark.types.FloatType, snowpark.types.DoubleType),
+ ):
+ # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
+ # Only apply this to floating-point source columns
+ modified_col = (
+ snowpark_fn.when(
+ snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
+ snowpark_fn.lit(None),
+ )
+ .otherwise(snowpark_fn.col(col_name))
+ .alias(col_name)
+ )
+ modified_columns.append(modified_col)
+ elif (
+ isinstance(target_field.datatype, snowpark.types.StructType)
+ and source_field.datatype != target_field.datatype
+ ):
+ # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+ # This fixes INSERT INTO table with struct literals like (2, 3)
+ modified_col = (
+ snowpark_fn.col(col_name)
+ .cast(target_field.datatype, rename_fields=True)
+ .alias(col_name)
+ )
+ modified_columns.append(modified_col)
+ else:
+ modified_columns.append(snowpark_fn.col(col_name))
+
+ df = df.select(modified_columns)
+ except Exception:
+ pass
+
+ queries = df.queries["queries"]
+ final_query = queries[0]
+ session.sql(
+ f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
+ ).collect()
+

  def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
- # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase",
+ # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
+ # if present, or to "spark.sql.caseSensitive".
  # and struct fields will be left as is. This should allow users to use the same names
  # in spark and Snowflake in most cases.
  if is_column:
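As the docstring inside _insert_into_table above explains, static PARTITION values are folded into the projected data as literal columns, because Snowflake has no static-partition syntax for INSERT INTO. A small standalone sketch of that rewrite, with made-up table and column names:

    import snowflake.snowpark.functions as snowpark_fn

    # Spark: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100)
    # The partition spec is appended to the dataframe as literal columns ...
    def add_partition_literals(df, partition_columns):
        for partition_col, partition_value in partition_columns.items():
            df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
        return df

    # ... so the statement sent to Snowflake is, conceptually:
    # INSERT INTO test_table SELECT 'k1', 100, '2021-01-01', 10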
@@ -300,6 +592,69 @@ def _remove_column_data_type(node):
  return node


+ def _get_condition_from_action(action, column_mapping, typer):
+ condition = None
+ if action.condition().isDefined():
+ (_, condition_typed_col,) = map_single_column_expression(
+ map_logical_plan_expression(action.condition().get()),
+ column_mapping,
+ typer,
+ )
+ condition = condition_typed_col.col
+ return condition
+
+
+ def _get_assignments_from_action(
+ action,
+ column_mapping_source,
+ column_mapping_target,
+ typer_source,
+ typer_target,
+ ):
+ assignments = dict()
+ if (
+ action.getClass().getSimpleName() == "InsertAction"
+ or action.getClass().getSimpleName() == "UpdateAction"
+ ):
+ incoming_assignments = as_java_list(action.assignments())
+ for assignment in incoming_assignments:
+ (_, key_typ_col) = map_single_column_expression(
+ map_logical_plan_expression(assignment.key()),
+ column_mapping=column_mapping_target,
+ typer=typer_target,
+ )
+ key_name = typer_target.df.select(key_typ_col.col).columns[0]
+
+ (_, val_typ_col) = map_single_column_expression(
+ map_logical_plan_expression(assignment.value()),
+ column_mapping=column_mapping_source,
+ typer=typer_source,
+ )
+
+ assignments[key_name] = val_typ_col.col
+ elif (
+ action.getClass().getSimpleName() == "InsertStarAction"
+ or action.getClass().getSimpleName() == "UpdateStarAction"
+ ):
+ if len(column_mapping_source.columns) != len(column_mapping_target.columns):
+ exception = ValueError(
+ "source and target must have the same number of columns for InsertStarAction or UpdateStarAction"
+ )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+ raise exception
+ for i, col in enumerate(column_mapping_target.columns):
+ if assignments.get(col.snowpark_name) is not None:
+ exception = SnowparkConnectNotImplementedError(
+ "UpdateStarAction or InsertStarAction is not supported with duplicate columns."
+ )
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+ raise exception
+ assignments[col.snowpark_name] = snowpark_fn.col(
+ column_mapping_source.columns[i].snowpark_name
+ )
+ return assignments
+
+
  def map_sql_to_pandas_df(
  sql_string: str,
  named_args: MutableMapping[str, expressions_proto.Expression.Literal],
@@ -311,7 +666,7 @@ def map_sql_to_pandas_df(
  returns a tuple of None for SELECT queries to enable lazy evaluation
  """

- snowpark_connect_sql_passthrough = get_sql_passthrough()
+ snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)

  if not snowpark_connect_sql_passthrough:
  logical_plan = sql_parser().parsePlan(sql_string)
@@ -327,6 +682,7 @@ def map_sql_to_pandas_df(
  ) == "UnresolvedHint":
  logical_plan = logical_plan.child()

+ # TODO: Add support for temporary views for SQL cases such as ShowViews, ShowColumns ect. (Currently the cases are not compatible with Spark, returning raw Snowflake rows)
  match class_name:
  case "AddColumns":
  # Handle ALTER TABLE ... ADD COLUMNS (col_name data_type) -> ADD COLUMN col_name data_type
@@ -397,9 +753,11 @@ def map_sql_to_pandas_df(
  snowflake_sql = f"ALTER TABLE {table_name} ALTER COLUMN {column_name} {alter_clause}"
  session.sql(snowflake_sql).collect()
  else:
- raise ValueError(
+ exception = ValueError(
  f"No alter operations found in AlterColumn logical plan for table {table_name}, column {column_name}"
  )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_SQL_SYNTAX)
+ raise exception
  case "CreateNamespace":
  name = get_relation_identifier_name(logical_plan.name(), True)
  previous_name = session.connection.schema
@@ -421,6 +779,9 @@ def map_sql_to_pandas_df(
  )

  name = get_relation_identifier_name(logical_plan.name())
+ full_table_identifier = get_relation_identifier_name(
+ logical_plan.name(), is_multi_part=True
+ )
  columns = ", ".join(
  _spark_field_to_sql(f, True)
  for f in logical_plan.tableSchema().fields()
@@ -431,10 +792,48 @@ def map_sql_to_pandas_df(
  if comment_opt.isDefined()
  else ""
  )
+
+ # Extract data source for metadata tracking
+ data_source = "default"
+
+ with suppress(Exception):
+ # Get data source from tableSpec.provider() (for USING clause)
+ if hasattr(logical_plan, "tableSpec"):
+ table_spec = logical_plan.tableSpec()
+ if hasattr(table_spec, "provider"):
+ provider_opt = table_spec.provider()
+ if provider_opt.isDefined():
+ data_source = str(provider_opt.get()).lower()
+ else:
+ # Fall back to checking properties for FORMAT
+ table_properties = table_spec.properties()
+ if not table_properties.isEmpty():
+ for prop in table_properties.get():
+ if str(prop.key()) == "FORMAT":
+ data_source = str(prop.value()).lower()
+ break
+
  # NOTE: We are intentionally ignoring any FORMAT=... parameters here.
  session.sql(
  f"CREATE {replace_table} TABLE {if_not_exists}{name} ({columns}) {comment}"
  ).collect()
+
+ # Record table metadata for Spark compatibility
+ # Tables created with explicit schema are considered v1 tables
+ # v1 tables with certain data sources don't support RENAME COLUMN in OSS Spark
+ supports_rename = data_source not in (
+ "parquet",
+ "csv",
+ "json",
+ "orc",
+ "avro",
+ )
+ record_table_metadata(
+ table_identifier=full_table_identifier,
+ table_type="v1",
+ data_source=data_source,
+ supports_column_rename=supports_rename,
+ )
  case "CreateTableAsSelect":
  mode = "ignore" if logical_plan.ignoreIfExists() else "errorifexists"
  _create_table_as_select(logical_plan, mode=mode)
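The metadata recorded here is what the RenameColumn branch further down consults through check_table_supports_operation. A hedged illustration of that round trip, assuming a made-up table identifier, an in-memory store keyed by it, and the Snowflake extension-behavior config left off:

    # Illustrative only: a v1 table backed by parquet is recorded as not supporting
    # ALTER TABLE ... RENAME COLUMN, so the later compatibility check rejects it.
    record_table_metadata(
        table_identifier="MY_DB.MY_SCHEMA.EVENTS",
        table_type="v1",
        data_source="parquet",
        supports_column_rename=False,
    )
    assert not check_table_supports_operation("MY_DB.MY_SCHEMA.EVENTS", "rename_column")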
@@ -446,20 +845,62 @@ def map_sql_to_pandas_df(
  f"CREATE TABLE {if_not_exists}{name} LIKE {source}"
  ).collect()
  case "CreateTempViewUsing":
+ parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")
+
+ spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name
+
+ # extract ONLY top-level column definitions (not nested struct fields)
+ column_defs = []
+ schema_node = next(parsed_sql.find_all(sqlglot.exp.Schema), None)
+ if schema_node:
+ for expr in schema_node.expressions:
+ if isinstance(expr, sqlglot.exp.ColumnDef):
+ column_defs.append(expr)
+
+ num_columns = len(column_defs)
+ if num_columns > 0:
+ null_list_parts = []
+ for col_def in column_defs:
+ col_name = spark_to_sf_single_id(col_def.name, is_column=True)
+ col_type = col_def.kind
+ if col_type:
+ null_list_parts.append(
+ f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}"
+ )
+ else:
+ null_list_parts.append(f"NULL AS {col_name}")
+ null_list = ", ".join(null_list_parts)
+ else:
+ null_list = "*"
+
  empty_select = (
- " AS SELECT * WHERE 1 = 0"
+ f" AS SELECT {null_list} WHERE 1 = 0"
  if logical_plan.options().isEmpty()
  and logical_plan.children().isEmpty()
  else ""
  )
- parsed_sql = (
- sqlglot.parse_one(sql_string, dialect="spark")
- .transform(_normalize_identifiers)
+
+ transformed_sql = (
+ parsed_sql.transform(_normalize_identifiers)
  .transform(_remove_column_data_type)
  .transform(_remove_file_format_property)
  )
- snowflake_sql = parsed_sql.sql(dialect="snowflake")
+ snowflake_sql = transformed_sql.sql(dialect="snowflake")
  session.sql(f"{snowflake_sql}{empty_select}").collect()
+ snowflake_view_name = spark_to_sf_single_id_with_unquoting(
+ spark_view_name
+ )
+ temp_view = get_temp_view(snowflake_view_name)
+ if temp_view is not None and not logical_plan.replace():
+ exception = AnalysisException(
+ f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{spark_view_name}` because it already exists."
+ )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+ raise exception
+ else:
+ unregister_temp_view(
+ spark_to_sf_single_id_with_unquoting(spark_view_name)
+ )
  case "CreateView":
  current_schema = session.connection.schema
  if (
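In effect, a CreateTempViewUsing statement with an explicit column list but no options and no child plan is rewritten into an empty, typed SELECT before being sent to Snowflake. A rough before/after sketch with a hypothetical view; identifier casing and the exact type names depend on the auto-uppercase config and on sqlglot's snowflake dialect:

    # Spark input:
    spark_sql = "CREATE TEMPORARY VIEW v (id INT, name STRING) USING parquet"

    # Roughly what the branch above emits: each declared column becomes a typed NULL,
    # and WHERE 1 = 0 keeps the view empty.
    snowflake_sql = (
        "CREATE TEMPORARY VIEW v AS "
        "SELECT CAST(NULL AS INT) AS id, CAST(NULL AS VARCHAR) AS name WHERE 1 = 0"
    )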
@@ -475,11 +916,13 @@ def map_sql_to_pandas_df(
  df_container = execute_logical_plan(logical_plan.query())
  df = df_container.dataframe
  if _accessing_temp_object.get():
- raise AnalysisException(
+ exception = AnalysisException(
  f"[INVALID_TEMP_OBJ_REFERENCE] Cannot create the persistent object `{CURRENT_CATALOG_NAME}`.`{current_schema}`.`{object_name}` "
  "of the type VIEW because it references to a temporary object of the type VIEW. Please "
  f"make the temporary object persistent, or make the persistent object `{CURRENT_CATALOG_NAME}`.`{current_schema}`.`{object_name}` temporary."
  )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+ raise exception

  name = get_relation_identifier_name(logical_plan.child())
  comment = logical_plan.comment()
@@ -496,58 +939,143 @@ def map_sql_to_pandas_df(
  else None,
  )
  case "CreateViewCommand":
- df_container = execute_logical_plan(logical_plan.plan())
- df = df_container.dataframe
- tmp_views = _get_current_temp_objects()
- tmp_views.add(
- (
- CURRENT_CATALOG_NAME,
- session.connection.schema,
- str(logical_plan.name().identifier()),
+ with push_processed_view(logical_plan.name().identifier()):
+ df_container = execute_logical_plan(logical_plan.plan())
+ df = df_container.dataframe
+ user_specified_spark_column_names = [
+ str(col._1())
+ for col in as_java_list(logical_plan.userSpecifiedColumns())
+ ]
+ df_container = DataFrameContainer.create_with_column_mapping(
+ dataframe=df,
+ spark_column_names=user_specified_spark_column_names
+ if user_specified_spark_column_names
+ else df_container.column_map.get_spark_columns(),
+ snowpark_column_names=df_container.column_map.get_snowpark_columns(),
+ parent_column_name_map=df_container.column_map,
  )
- )

- name = str(logical_plan.name().identifier())
- name = spark_to_sf_single_id(name)
- if isinstance(
- logical_plan.viewType(),
- jpype.JClass(
- "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
- ),
- ):
- name = f"{global_config.spark_sql_globalTempDatabase}.{name}"
- comment = logical_plan.comment()
- maybe_comment = (
- _escape_sql_comment(str(comment.get()))
- if comment.isDefined()
- else None
- )
+ is_global = isinstance(
+ logical_plan.viewType(),
+ jpype.JClass(
+ "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
+ ),
+ )

- df = _rename_columns(
- df, logical_plan.userSpecifiedColumns(), df_container.column_map
- )
+ def get_cached_view_name() -> str:
+ if is_global:
+ view_name = [
+ global_config.spark_sql_globalTempDatabase,
+ logical_plan.name().quotedString(),
+ ]
+ else:
+ view_name = [logical_plan.name().quotedString()]
+ view_name = [
+ spark_to_sf_single_id_with_unquoting(part)
+ for part in view_name
+ ]
+ return ".".join(view_name)
+
+ def get_snowflake_view_name() -> list[str]:
+ snowpark_view_name = str(logical_plan.name().identifier())
+ snowpark_view_name = spark_to_sf_single_id(snowpark_view_name)
+ return (
+ [
+ global_config.spark_sql_globalTempDatabase,
+ snowpark_view_name,
+ ]
+ if is_global
+ else [snowpark_view_name]
+ )

- if logical_plan.replace():
- df.create_or_replace_temp_view(
- name,
- comment=maybe_comment,
- )
- else:
- df.create_temp_view(
- name,
- comment=maybe_comment,
+ snowflake_view_name = get_snowflake_view_name()
+ cached_view_name = get_cached_view_name()
+
+ tmp_views = _get_current_temp_objects()
+ tmp_views.add(
+ (
+ CURRENT_CATALOG_NAME,
+ session.connection.schema,
+ str(logical_plan.name().identifier()),
+ )
  )
+
+ def _create_snowflake_temporary_view():
+ comment = logical_plan.comment()
+ maybe_comment = (
+ _escape_sql_comment(str(comment.get()))
+ if comment.isDefined()
+ else None
+ )
+
+ renamed_df = _rename_columns(
+ df,
+ logical_plan.userSpecifiedColumns(),
+ df_container.column_map,
+ )
+
+ create_snowflake_temporary_view(
+ renamed_df,
+ snowflake_view_name,
+ cached_view_name,
+ logical_plan.replace(),
+ maybe_comment,
+ )
+
+ if should_create_temporary_view_in_snowflake():
+ _create_snowflake_temporary_view()
+ else:
+ user_specified_spark_column_names = [
+ str(col._1())
+ for col in as_java_list(logical_plan.userSpecifiedColumns())
+ ]
+ spark_column_names = (
+ user_specified_spark_column_names
+ if user_specified_spark_column_names
+ else df_container.column_map.get_spark_columns()
+ )
+ store_temporary_view_as_dataframe(
+ df,
+ df_container.column_map,
+ spark_column_names,
+ df_container.column_map.get_snowpark_columns(),
+ cached_view_name,
+ snowflake_view_name,
+ logical_plan.replace(),
+ )
  case "DescribeColumn":
- name = get_relation_identifier_name(logical_plan.column())
+ name = get_relation_identifier_name_without_uppercasing(
+ logical_plan.column()
+ )
+ stored_temp_view = get_temp_view(name)
+ if stored_temp_view:
+ return (
+ SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
+ stored_temp_view
+ ),
+ "",
+ )
  # todo double check if this is correct
+ name = get_relation_identifier_name(logical_plan.column())
  rows = session.sql(f"DESCRIBE TABLE {name}").collect()
  case "DescribeNamespace":
  name = get_relation_identifier_name(logical_plan.namespace(), True)
- name = change_default_to_public(name)
  rows = session.sql(f"DESCRIBE SCHEMA {name}").collect()
  if not rows:
  rows = None
  case "DescribeRelation":
+ name = get_relation_identifier_name_without_uppercasing(
+ logical_plan.relation(), True
+ )
+ stored_temp_view = get_temp_view(name)
+ if stored_temp_view:
+ return (
+ SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
+ stored_temp_view
+ ),
+ "",
+ )
+
  name = get_relation_identifier_name(logical_plan.relation(), True)
  rows = session.sql(f"DESCRIBE TABLE {name}").collect()
  if not rows:
@@ -598,9 +1126,11 @@ def map_sql_to_pandas_df(
  del session._udtfs[func_name]
  else:
  if not logical_plan.ifExists():
- raise ValueError(
+ exception = ValueError(
  f"Function {func_name} not found among registered UDFs or UDTFs."
  )
+ attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+ raise exception
  if snowpark_name != "":
  argument_string = f"({', '.join(convert_sp_to_sf_type(arg) for arg in input_types)})"
  session.sql(
@@ -615,9 +1145,13 @@ def map_sql_to_pandas_df(
  if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
  session.sql(f"DROP TABLE {if_exists}{name}").collect()
  case "DropView":
- name = get_relation_identifier_name(logical_plan.child())
- if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
- session.sql(f"DROP VIEW {if_exists}{name}").collect()
+ temporary_view_name = get_relation_identifier_name_without_uppercasing(
+ logical_plan.child()
+ )
+ if not unregister_temp_view(temporary_view_name):
+ name = get_relation_identifier_name(logical_plan.child())
+ if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
+ session.sql(f"DROP VIEW {if_exists}{name}").collect()
  case "ExplainCommand":
  inner_plan = logical_plan.logicalPlan()
  logical_plan_name = inner_plan.nodeName()
@@ -669,84 +1203,189 @@ def map_sql_to_pandas_df(
669
1203
  rows = session.sql(final_sql).collect()
670
1204
  else:
671
1205
  # TODO: Support other logical plans
672
- raise SnowparkConnectNotImplementedError(
1206
+ exception = SnowparkConnectNotImplementedError(
673
1207
  f"{logical_plan_name} is not supported yet with EXPLAIN."
674
1208
  )
1209
+ attach_custom_error_code(
1210
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1211
+ )
1212
+ raise exception
675
1213
  case "InsertIntoStatement":
676
- df_container = execute_logical_plan(logical_plan.query())
677
- df = df_container.dataframe
678
- queries = df.queries["queries"]
679
- if len(queries) != 1:
680
- raise SnowparkConnectNotImplementedError(
681
- f"Unexpected number of queries: {len(queries)}"
1214
+ _insert_into_table(logical_plan, session)
1215
+ case "MergeIntoTable":
1216
+ source_df_container = map_relation(
1217
+ map_logical_plan_relation(logical_plan.sourceTable())
1218
+ )
1219
+ source_df = source_df_container.dataframe
1220
+ plan_id = gen_sql_plan_id()
1221
+ target_df_container = map_relation(
1222
+ map_logical_plan_relation(logical_plan.targetTable(), plan_id)
1223
+ )
1224
+ target_df = target_df_container.dataframe
1225
+
1226
+ if (
1227
+ logical_plan.targetTable().getClass().getSimpleName()
1228
+ == "UnresolvedRelation"
1229
+ ):
1230
+ target_table_name = _spark_to_snowflake(
1231
+ logical_plan.targetTable().multipartIdentifier()
1232
+ )
1233
+ else:
1234
+ target_table_name = _spark_to_snowflake(
1235
+ logical_plan.targetTable().child().multipartIdentifier()
682
1236
  )
683
1237
 
684
- name = get_relation_identifier_name(logical_plan.table(), True)
1238
+ target_table = session.table(target_table_name)
1239
+ target_table_columns = target_table.columns
1240
+ target_df_spark_names = []
1241
+ for target_table_col, target_df_col in zip(
1242
+ target_table_columns, target_df_container.column_map.columns
1243
+ ):
1244
+ target_df = target_df.with_column_renamed(
1245
+ target_df_col.snowpark_name,
1246
+ target_table_col,
1247
+ )
1248
+ target_df_spark_names.append(target_df_col.spark_name)
1249
+ target_df_container = DataFrameContainer.create_with_column_mapping(
1250
+ dataframe=target_df,
1251
+ spark_column_names=target_df_spark_names,
1252
+ snowpark_column_names=target_table_columns,
1253
+ )
685
1254
 
686
- user_columns = [
687
- spark_to_sf_single_id(str(col), is_column=True)
688
- for col in as_java_list(logical_plan.userSpecifiedCols())
689
- ]
690
- overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
691
- cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
1255
+ set_plan_id_map(plan_id, target_df_container)
692
1256
 
693
- try:
694
- target_table = session.table(name)
695
- target_schema = target_table.schema
1257
+ joined_df_before_condition: snowpark.DataFrame = source_df.join(
1258
+ target_df
1259
+ )
696
1260
 
697
- # Modify df with NaN → NULL conversion for DECIMAL columns
698
- modified_columns = []
699
- for source_field, target_field in zip(
700
- df.schema.fields, target_schema.fields
701
- ):
702
- col_name = source_field.name
703
- if isinstance(
704
- target_field.datatype, snowpark.types.DecimalType
705
- ) and isinstance(
706
- source_field.datatype,
707
- (snowpark.types.FloatType, snowpark.types.DoubleType),
708
- ):
709
- # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
710
- # Only apply this to floating-point source columns
711
- modified_col = (
712
- snowpark_fn.when(
713
- snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
714
- snowpark_fn.lit(None),
715
- )
716
- .otherwise(snowpark_fn.col(col_name))
717
- .alias(col_name)
718
- )
719
- modified_columns.append(modified_col)
720
- else:
721
- modified_columns.append(snowpark_fn.col(col_name))
1261
+ column_mapping_for_conditions = column_name_handler.JoinColumnNameMap(
1262
+ source_df_container.column_map,
1263
+ target_df_container.column_map,
1264
+ )
1265
+ typer_for_expressions = ExpressionTyper(joined_df_before_condition)
722
1266
 
723
- df = df.select(modified_columns)
724
- except Exception:
725
- pass
726
- queries = df.queries["queries"]
727
- final_query = queries[0]
728
- session.sql(
729
- f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
730
- ).collect()
731
- case "MergeIntoTable":
732
- raise UnsupportedOperationException(
733
- "[UNSUPPORTED_SQL_EXTENSION] The MERGE INTO command failed.\n"
734
- + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
1267
+ (_, merge_condition_typed_col,) = map_single_column_expression(
1268
+ map_logical_plan_expression(logical_plan.mergeCondition()),
1269
+ column_mapping=column_mapping_for_conditions,
1270
+ typer=typer_for_expressions,
735
1271
  )
1272
+
1273
+ clauses = []
1274
+
1275
+ for matched_action in as_java_list(logical_plan.matchedActions()):
1276
+ condition = _get_condition_from_action(
1277
+ matched_action,
1278
+ column_mapping_for_conditions,
1279
+ typer_for_expressions,
1280
+ )
1281
+ if matched_action.getClass().getSimpleName() == "DeleteAction":
1282
+ clauses.append(when_matched(condition).delete())
1283
+ elif (
1284
+ matched_action.getClass().getSimpleName() == "UpdateAction"
1285
+ or matched_action.getClass().getSimpleName()
1286
+ == "UpdateStarAction"
1287
+ ):
1288
+ assignments = _get_assignments_from_action(
1289
+ matched_action,
1290
+ source_df_container.column_map,
1291
+ target_df_container.column_map,
1292
+ ExpressionTyper(source_df),
1293
+ ExpressionTyper(target_df),
1294
+ )
1295
+ clauses.append(when_matched(condition).update(assignments))
1296
+
1297
+ for not_matched_action in as_java_list(
1298
+ logical_plan.notMatchedActions()
1299
+ ):
1300
+ condition = _get_condition_from_action(
1301
+ not_matched_action,
1302
+ column_mapping_for_conditions,
1303
+ typer_for_expressions,
1304
+ )
1305
+ if (
1306
+ not_matched_action.getClass().getSimpleName() == "InsertAction"
1307
+ or not_matched_action.getClass().getSimpleName()
1308
+ == "InsertStarAction"
1309
+ ):
1310
+ assignments = _get_assignments_from_action(
1311
+ not_matched_action,
1312
+ source_df_container.column_map,
1313
+ target_df_container.column_map,
1314
+ ExpressionTyper(source_df),
1315
+ ExpressionTyper(target_df),
1316
+ )
1317
+ clauses.append(when_not_matched(condition).insert(assignments))
1318
+
1319
+ if not as_java_list(logical_plan.notMatchedBySourceActions()).isEmpty():
1320
+ exception = SnowparkConnectNotImplementedError(
1321
+ "Snowflake does not support 'not matched by source' actions in MERGE statements."
1322
+ )
1323
+ attach_custom_error_code(
1324
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1325
+ )
1326
+ raise exception
1327
+
1328
+ target_table.merge(source_df, merge_condition_typed_col.col, clauses)
736
1329
  case "DeleteFromTable":
737
- raise UnsupportedOperationException(
738
- "[UNSUPPORTED_SQL_EXTENSION] The DELETE FROM command failed.\n"
739
- + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
1330
+ df_container = map_relation(
1331
+ map_logical_plan_relation(logical_plan.table())
740
1332
  )
1333
+ name = get_relation_identifier_name(logical_plan.table(), True)
1334
+ table = session.table(name)
1335
+ table_columns = table.columns
1336
+ df = df_container.dataframe
1337
+ spark_names = []
1338
+ for table_col, df_col in zip(
1339
+ table_columns, df_container.column_map.columns
1340
+ ):
1341
+ df = df.with_column_renamed(
1342
+ df_col.snowpark_name,
1343
+ table_col,
1344
+ )
1345
+ spark_names.append(df_col.spark_name)
1346
+ df_container = DataFrameContainer.create_with_column_mapping(
1347
+ dataframe=df,
1348
+ spark_column_names=spark_names,
1349
+ snowpark_column_names=table_columns,
1350
+ )
1351
+ df = df_container.dataframe
1352
+ (
1353
+ condition_column_name,
1354
+ condition_typed_col,
1355
+ ) = map_single_column_expression(
1356
+ map_logical_plan_expression(logical_plan.condition()),
1357
+ df_container.column_map,
1358
+ ExpressionTyper(df),
1359
+ )
1360
+ table.delete(condition_typed_col.col)
741
1361
  case "UpdateTable":
742
1362
  # Databricks/Delta-specific extension not supported by SAS.
743
1363
  # Provide an actionable, clear error.
744
- raise UnsupportedOperationException(
1364
+ exception = UnsupportedOperationException(
745
1365
  "[UNSUPPORTED_SQL_EXTENSION] The UPDATE TABLE command failed.\n"
746
1366
  + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
747
1367
  )
1368
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
1369
+ raise exception
748
1370
  case "RenameColumn":
749
- table_name = get_relation_identifier_name(logical_plan.table(), True)
1371
+ full_table_identifier = get_relation_identifier_name(
1372
+ logical_plan.table(), True
1373
+ )
1374
+
1375
+ # Check Spark compatibility for RENAME COLUMN operation
1376
+ if not check_table_supports_operation(
1377
+ full_table_identifier, "rename_column"
1378
+ ):
1379
+ exception = AnalysisException(
1380
+ f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
1381
+ f"This table was created as a v1 table with a data source that doesn't support column renaming. "
1382
+ f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
1383
+ )
1384
+ attach_custom_error_code(
1385
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1386
+ )
1387
+ raise exception
1388
+
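Assuming a Spark Connect session named `spark`, the gate above can be lifted with the config flag named in the error message, for example:

    # Hypothetical usage; the flag name is taken from the error message above.
    spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")
    spark.sql("ALTER TABLE my_table RENAME COLUMN old_name TO new_name")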
750
1389
  column_obj = logical_plan.column()
751
1390
  old_column_name = ".".join(
752
1391
  spark_to_sf_single_id(str(part), is_column=True)
@@ -756,7 +1395,7 @@ def map_sql_to_pandas_df(
756
1395
  case_insensitive_name = next(
757
1396
  (
758
1397
  f.name
759
- for f in session.table(table_name).schema.fields
1398
+ for f in session.table(full_table_identifier).schema.fields
760
1399
  if f.name.lower() == old_column_name.lower()
761
1400
  ),
762
1401
  None,
@@ -768,7 +1407,7 @@ def map_sql_to_pandas_df(
768
1407
  )
769
1408
 
770
1409
  # Pass through to Snowflake
771
- snowflake_sql = f"ALTER TABLE {table_name} RENAME COLUMN {old_column_name} TO {new_column_name}"
1410
+ snowflake_sql = f"ALTER TABLE {full_table_identifier} RENAME COLUMN {old_column_name} TO {new_column_name}"
772
1411
  session.sql(snowflake_sql).collect()
773
1412
  case "RenameTable":
774
1413
  name = get_relation_identifier_name(logical_plan.child(), True)
@@ -786,30 +1425,31 @@ def map_sql_to_pandas_df(
786
1425
  f"ALTER ICEBERG TABLE {name} RENAME TO {new_name}"
787
1426
  ).collect()
788
1427
  else:
1428
+ attach_custom_error_code(e, ErrorCodes.INTERNAL_ERROR)
789
1429
  raise e
790
1430
  case "ReplaceTableAsSelect":
791
1431
  _create_table_as_select(logical_plan, mode="overwrite")
792
1432
  case "ResetCommand":
793
1433
  key = logical_plan.config().get()
794
- unset_config_param(get_session_id(), key, session)
1434
+ unset_config_param(get_spark_session_id(), key, session)
795
1435
  case "SetCatalogAndNamespace":
796
1436
  # TODO: add catalog setting here
797
1437
  name = get_relation_identifier_name(logical_plan.child(), True)
798
- name = change_default_to_public(name)
799
1438
  session.sql(f"USE SCHEMA {name}").collect()
800
1439
  case "SetCommand":
801
1440
  kv_result_tuple = logical_plan.kv().get()
802
1441
  key = kv_result_tuple._1()
803
1442
  val = kv_result_tuple._2().get()
804
- set_config_param(get_session_id(), key, val, session)
1443
+ set_config_param(get_spark_session_id(), key, val, session)
805
1444
  case "SetNamespaceCommand":
806
1445
  name = _spark_to_snowflake(logical_plan.namespace())
807
- name = change_default_to_public(name)
808
1446
  session.sql(f"USE SCHEMA {name}").collect()
809
1447
  case "SetNamespaceLocation" | "SetNamespaceProperties":
810
- raise SnowparkConnectNotImplementedError(
1448
+ exception = SnowparkConnectNotImplementedError(
811
1449
  "Altering databases is not currently supported."
812
1450
  )
1451
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
1452
+ raise exception
813
1453
  case "ShowCreateTable":
814
1454
  # Handle SHOW CREATE TABLE command
815
1455
  # Spark: SHOW CREATE TABLE table_name
@@ -831,16 +1471,24 @@ def map_sql_to_pandas_df(
831
1471
  case "ShowNamespaces":
832
1472
  name = get_relation_identifier_name(logical_plan.namespace(), True)
833
1473
  if name:
834
- raise SnowparkConnectNotImplementedError(
1474
+ exception = SnowparkConnectNotImplementedError(
835
1475
  "'IN' clause is not supported while listing databases"
836
1476
  )
1477
+ attach_custom_error_code(
1478
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1479
+ )
1480
+ raise exception
837
1481
  if logical_plan.pattern().isDefined():
838
1482
  # Snowflake SQL requires a "%" pattern.
839
1483
  # Snowpark catalog requires a regex and does client-side filtering.
840
1484
  # Spark, however, uses a regex-like pattern that treats '*' and '|' differently.
841
- raise SnowparkConnectNotImplementedError(
1485
+ exception = SnowparkConnectNotImplementedError(
842
1486
  "'LIKE' clause is not supported while listing databases"
843
1487
  )
1488
+ attach_custom_error_code(
1489
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1490
+ )
1491
+ raise exception
844
1492
  rows = session.sql("SHOW SCHEMAS").collect()
845
1493
  if not rows:
846
1494
  rows = None
@@ -913,6 +1561,18 @@ def map_sql_to_pandas_df(
913
1561
  if pattern and rows:
914
1562
  rows = _filter_tables_by_pattern(rows, pattern)
915
1563
  case "ShowColumns":
1564
+ name = get_relation_identifier_name_without_uppercasing(
1565
+ logical_plan.child(), True
1566
+ )
1567
+ stored_temp_view = get_temp_view(name)
1568
+ if stored_temp_view:
1569
+ return (
1570
+ SNOWFLAKE_CATALOG._list_columns_from_dataframe_container(
1571
+ stored_temp_view
1572
+ ),
1573
+ "",
1574
+ )
1575
+
916
1576
  # Handle Spark SQL: SHOW COLUMNS IN table_name FROM database_name
917
1577
  # Convert to Snowflake SQL: SHOW COLUMNS IN TABLE database_name.table_name
918
1578
 
@@ -941,9 +1601,13 @@ def map_sql_to_pandas_df(
941
1601
  spark_to_sf_single_id(str(db_and_table_name[0])).casefold()
942
1602
  != db_name.casefold()
943
1603
  ):
944
- raise AnalysisException(
1604
+ exception = AnalysisException(
945
1605
  f"database name is not matching:{db_name} and {db_and_table_name[0]}"
946
1606
  )
1607
+ attach_custom_error_code(
1608
+ exception, ErrorCodes.INVALID_OPERATION
1609
+ )
1610
+ raise exception
947
1611
 
948
1612
  # Just table name
949
1613
  snowflake_cmd = f"SHOW COLUMNS IN TABLE {table_name}"
@@ -981,6 +1645,51 @@ def map_sql_to_pandas_df(
981
1645
  return pandas.DataFrame({"": [""]}), ""
982
1646
 
983
1647
  rows = session.sql(snowflake_sql).collect()
1648
+ case "RefreshTable":
1649
+ table_name_unquoted = ".".join(
1650
+ str(part)
1651
+ for part in as_java_list(logical_plan.child().multipartIdentifier())
1652
+ )
1653
+ SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
1654
+
1655
+ return pandas.DataFrame({"": [""]}), ""
1656
+ case "RepairTable":
1657
+ # No-Op: Snowflake doesn't have explicit partitions to repair.
1658
+ table_relation = logical_plan.child()
1659
+ db_and_table_name = as_java_list(table_relation.multipartIdentifier())
1660
+ multi_part_len = len(db_and_table_name)
1661
+
1662
+ if multi_part_len == 1:
1663
+ table_name = db_and_table_name[0]
1664
+ db_name = None
1665
+ full_table_name = table_name
1666
+ else:
1667
+ db_name = db_and_table_name[0]
1668
+ table_name = db_and_table_name[1]
1669
+ full_table_name = db_name + "." + table_name
1670
+
1671
+ df = SNOWFLAKE_CATALOG.tableExists(table_name, db_name)
1672
+
1673
+ table_exist = df.iloc[0, 0]
1674
+
1675
+ if not table_exist:
1676
+ exception = AnalysisException(
1677
+ f"[TABLE_OR_VIEW_NOT_FOUND] Table not found `{full_table_name}`."
1678
+ )
1679
+ attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
1680
+ raise exception
1681
+
1682
+ return pandas.DataFrame({"": [""]}), ""
1683
+ case "UnresolvedWith":
1684
+ child = logical_plan.child()
1685
+ child_class = str(child.getClass().getSimpleName())
1686
+ match child_class:
1687
+ case "InsertIntoStatement":
1688
+ with _with_cte_scope(logical_plan.cteRelations()):
1689
+ _insert_into_table(child, get_or_create_snowpark_session())
1690
+ case _:
1691
+ execute_logical_plan(logical_plan)
1692
+ return None, None
984
1693
  case _:
985
1694
  execute_logical_plan(logical_plan)
986
1695
  return None, None
@@ -1001,6 +1710,27 @@ def get_sql_passthrough() -> bool:
1001
1710
  return get_boolean_session_config_param("snowpark.connect.sql.passthrough")
1002
1711
 
1003
1712
 
1713
+ def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
1714
+ """
1715
+ Checks whether :param sql_stmt: should be executed as SQL pass-through. Pass-through is detected in one of two ways:
1716
+ 1) the Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated), or
1717
+ 2) :param sql_stmt: was created through SnowflakeSession and carries the correct marker and checksum.
1718
+ """
1719
+ if get_sql_passthrough():
1720
+ # legacy-style pass-through; sql_stmt must be a complete, valid SQL statement
1721
+ return True, sql_stmt
1722
+
1723
+ # check for the new-style, SnowflakeSession-based SQL pass-through
1724
+ sql_parts = sql_stmt.split(" ", 2)
1725
+ if len(sql_parts) == 3:
1726
+ marker, checksum, sql = sql_parts
1727
+ if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
1728
+ return True, sql
1729
+
1730
+ # Not a SQL pass-through
1731
+ return False, sql_stmt
1732
+
1733
+
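A minimal, self-contained sketch of the marker-plus-checksum convention checked above; the marker value and digest below are placeholders standing in for SQL_PASS_THROUGH_MARKER and calculate_checksum, which are defined elsewhere in the package:

    import hashlib

    MARKER = "SNOWPARK_SQL_PASSTHROUGH"      # placeholder for SQL_PASS_THROUGH_MARKER

    def _checksum(sql: str) -> str:          # placeholder for calculate_checksum
        return hashlib.sha256(sql.encode("utf-8")).hexdigest()

    sql = "SELECT CURRENT_VERSION()"
    wrapped = f"{MARKER} {_checksum(sql)} {sql}"

    # The check above splits on the first two spaces and verifies the checksum.
    marker, checksum, body = wrapped.split(" ", 2)
    assert marker == MARKER and checksum == _checksum(body)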
1004
1734
  def change_default_to_public(name: str) -> str:
1005
1735
  """
1006
1736
  Change the namespace to PUBLIC when given name is DEFAULT
@@ -1015,6 +1745,76 @@ def change_default_to_public(name: str) -> str:
1015
1745
  return name
1016
1746
 
1017
1747
 
1748
+ def _preprocess_identifier_calls(sql_query: str) -> str:
1749
+ """
1750
+ Pre-process SQL query to resolve IDENTIFIER() calls before Spark parsing.
1751
+
1752
+ Transforms: IDENTIFIER('abs')(c2) -> abs(c2)
1753
+ Transforms: IDENTIFIER('COAL' || 'ESCE')(NULL, 1) -> COALESCE(NULL, 1)
1754
+
1755
+ This preserves all function arguments in their original positions, eliminating
1756
+ the need to reconstruct them at the expression level.
1757
+ """
1758
+ import re
1759
+
1760
+ # Pattern to match IDENTIFIER(...) followed by optional function call arguments
1761
+ # This captures both the identifier expression and any trailing arguments
1762
+ # Note: We need to be careful about whitespace preservation
1763
+ identifier_pattern = r"IDENTIFIER\s*\(\s*([^)]+)\s*\)(\s*)(\([^)]*\))?"
1764
+
1765
+ def resolve_identifier_match(match):
1766
+ identifier_expr_str = match.group(1).strip()
1767
+ whitespace = match.group(2) if match.group(2) else ""
1768
+ function_args = match.group(3) if match.group(3) else ""
1769
+
1770
+ try:
1771
+ # Handle string concatenation FIRST: IDENTIFIER('COAL' || 'ESCE')
1772
+ # (Must check this before simple strings since it also starts/ends with quotes)
1773
+ if "||" in identifier_expr_str:
1774
+ # Parse basic string concatenation with proper quote handling
1775
+ parts = []
1776
+ split_parts = identifier_expr_str.split("||")
1777
+ for part in split_parts:
1778
+ part = part.strip()
1779
+ if part.startswith("'") and part.endswith("'"):
1780
+ unquoted = part[1:-1] # Remove quotes from each part
1781
+ parts.append(unquoted)
1782
+ else:
1783
+ # Non-string parts - return original for safety
1784
+ return match.group(0)
1785
+ resolved_name = "".join(parts) # Concatenate the unquoted parts
1786
+
1787
+ # Handle simple string literals: IDENTIFIER('abs')
1788
+ elif identifier_expr_str.startswith("'") and identifier_expr_str.endswith(
1789
+ "'"
1790
+ ):
1791
+ resolved_name = identifier_expr_str[1:-1] # Remove quotes
1792
+
1793
+ else:
1794
+ # Complex expressions not supported yet - return original
1795
+ return match.group(0)
1796
+
1797
+ # Return resolved function call with preserved arguments and whitespace
1798
+ if function_args:
1799
+ # Function call case: IDENTIFIER('abs')(c1) -> abs(c1)
1800
+ result = f"{resolved_name}{function_args}"
1801
+ else:
1802
+ # Column reference case: IDENTIFIER('c1') FROM -> c1 FROM (preserve whitespace)
1803
+ result = f"{resolved_name}{whitespace}"
1804
+ return result
1805
+
1806
+ except Exception:
1807
+ # Return original to avoid breaking the query
1808
+ return match.group(0)
1809
+
1810
+ # Apply the transformation
1811
+ processed_query = re.sub(
1812
+ identifier_pattern, resolve_identifier_match, sql_query, flags=re.IGNORECASE
1813
+ )
1814
+
1815
+ return processed_query
1816
+
1817
+
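A quick sanity check of the rewrites described in the docstring, assuming the helper above is in scope:

    assert (
        _preprocess_identifier_calls("SELECT IDENTIFIER('abs')(c2) FROM t")
        == "SELECT abs(c2) FROM t"
    )
    assert (
        _preprocess_identifier_calls("SELECT IDENTIFIER('COAL' || 'ESCE')(NULL, 1)")
        == "SELECT COALESCE(NULL, 1)"
    )
    # Expressions that cannot be resolved statically are left untouched.
    assert (
        _preprocess_identifier_calls("SELECT IDENTIFIER(my_var)(c2) FROM t")
        == "SELECT IDENTIFIER(my_var)(c2) FROM t"
    )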
1018
1818
  def map_sql(
1019
1819
  rel: relation_proto.Relation,
1020
1820
  ) -> DataFrameContainer:
@@ -1026,10 +1826,15 @@ def map_sql(
1026
1826
  In passthrough mode, SAS calls session.sql() directly instead of invoking the Spark parser.
1027
1827
  This mitigates any issue not covered by the Spark logical plan to protobuf conversion.
1028
1828
  """
1029
- snowpark_connect_sql_passthrough = get_sql_passthrough()
1829
+ snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)
1030
1830
 
1031
1831
  if not snowpark_connect_sql_passthrough:
1032
- logical_plan = sql_parser().parseQuery(rel.sql.query)
1832
+ # Changed from parseQuery to parsePlan because Spark's parseQuery() generates an incorrect logical plan for
1833
+ # queries such as: SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo'
1834
+ # Elsewhere in this file we already use parsePlan.
1835
+ # The main difference is that parsePlan() accepts any SQL statement, while
1836
+ # parseQuery() accepts only query statements.
1837
+ logical_plan = sql_parser().parsePlan(sql_stmt)
1033
1838
 
1034
1839
  parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
1035
1840
  set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1037,7 +1842,7 @@ def map_sql(
1037
1842
  return execute_logical_plan(logical_plan)
1038
1843
  else:
1039
1844
  session = snowpark.Session.get_active_session()
1040
- sql_df = session.sql(rel.sql.query)
1845
+ sql_df = session.sql(sql_stmt)
1041
1846
  columns = sql_df.columns
1042
1847
  return DataFrameContainer.create_with_column_mapping(
1043
1848
  dataframe=sql_df,
@@ -1112,7 +1917,19 @@ def map_logical_plan_relation(
1112
1917
  attr_parts = as_java_list(expr.nameParts())
1113
1918
  if len(attr_parts) == 1:
1114
1919
  attr_name = str(attr_parts[0])
1115
- return alias_map.get(attr_name, expr)
1920
+ if attr_name in alias_map:
1921
+ # Check if the alias references an aggregate function
1922
+ # If so, don't substitute because you can't GROUP BY an aggregate
1923
+ aliased_expr = alias_map[attr_name]
1924
+ aliased_expr_class = str(
1925
+ aliased_expr.getClass().getSimpleName()
1926
+ )
1927
+ if aliased_expr_class == "UnresolvedFunction":
1928
+ func_name = str(aliased_expr.nameParts().head())
1929
+ if is_aggregate_function(func_name):
1930
+ return expr
1931
+ return aliased_expr
1932
+ return expr
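    # Illustration of the guard above (hypothetical query): in
    #   SELECT region, SUM(amount) AS amount FROM sales GROUP BY region, amount
    # "amount" aliases an aggregate, so the GROUP BY reference is kept as the raw
    # column reference (sales.amount) instead of being substituted with SUM(amount),
    # which would be invalid inside a GROUP BY.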
1116
1933
 
1117
1934
  return expr
1118
1935
 
@@ -1129,9 +1946,13 @@ def map_logical_plan_relation(
1129
1946
  group_type = snowflake_proto.Aggregate.GROUP_TYPE_CUBE
1130
1947
  case "GroupingSets":
1131
1948
  if not exp.userGivenGroupByExprs().isEmpty():
1132
- raise SnowparkConnectNotImplementedError(
1949
+ exception = SnowparkConnectNotImplementedError(
1133
1950
  "User-defined group by expressions are not supported"
1134
1951
  )
1952
+ attach_custom_error_code(
1953
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1954
+ )
1955
+ raise exception
1135
1956
  group_type = (
1136
1957
  snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS
1137
1958
  )
@@ -1147,9 +1968,13 @@ def map_logical_plan_relation(
1147
1968
 
1148
1969
  if group_type != snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
1149
1970
  if len(group_expression_list) != 1:
1150
- raise SnowparkConnectNotImplementedError(
1971
+ exception = SnowparkConnectNotImplementedError(
1151
1972
  "Multiple grouping expressions are not supported"
1152
1973
  )
1974
+ attach_custom_error_code(
1975
+ exception, ErrorCodes.UNSUPPORTED_OPERATION
1976
+ )
1977
+ raise exception
1153
1978
  if group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS:
1154
1979
  group_expression_list = [] # TODO: exp.userGivenGroupByExprs()?
1155
1980
  else:
@@ -1281,38 +2106,89 @@ def map_logical_plan_relation(
1281
2106
  case "Pivot":
1282
2107
  pivot_column = map_logical_plan_expression(rel.pivotColumn())
1283
2108
  session = snowpark.Session.get_active_session()
1284
- m = ColumnNameMap([], [], None)
2109
+ m = ColumnNameMap([], [])
1285
2110
 
1286
- pivot_values = [
1287
- map_logical_plan_expression(e) for e in as_java_list(rel.pivotValues())
1288
- ]
2111
+ pivot_columns = (
2112
+ [
2113
+ col
2114
+ for col in pivot_column.unresolved_function.arguments
2115
+ if col.HasField("unresolved_attribute")
2116
+ ]
2117
+ if pivot_column.HasField("unresolved_function")
2118
+ else [pivot_column]
2119
+ )
1289
2120
 
1290
- pivot_literals = []
2121
+ typer = ExpressionTyper.dummy_typer(session)
1291
2122
 
1292
- for expr_proto in pivot_values:
1293
- expr = map_single_column_expression(
1294
- expr_proto, m, ExpressionTyper.dummy_typer(session)
2123
+ expression_protos: list[expressions_proto.Expression] = []
2124
+ expressions: list[TypedColumn] = []
2125
+ aliases: list[str] = []
2126
+
2127
+ for pivot_value in as_java_list(rel.pivotValues()):
2128
+ expr_proto = map_logical_plan_expression(pivot_value)
2129
+ alias, expr = map_single_column_expression(expr_proto, m, typer)
2130
+
2131
+ expression_protos.append(expr_proto)
2132
+ expressions.append(expr)
2133
+ aliases.append(alias)
2134
+
2135
+ resolved_pivot_values_row = (
2136
+ session.range(1)
2137
+ .select(*[expr.col for expr in expressions])
2138
+ .collect()[0]
2139
+ )
2140
+ resolved_pivot_values = [value for value in resolved_pivot_values_row]
2141
+
2142
+ pivot_values = []
2143
+ for expr_proto, expr, alias, value in zip(
2144
+ expression_protos, expressions, aliases, resolved_pivot_values
2145
+ ):
2146
+ literals_proto = (
2147
+ [
2148
+ _map_value_to_literal_proto(v, expr.typ.fields[i].datatype)
2149
+ for i, v in enumerate(value)
2150
+ ]
2151
+ if isinstance(expr.typ, snowpark.types.StructType)
2152
+ else [_map_value_to_literal_proto(value, expr.typ)]
1295
2153
  )
1296
- value = session.range(1).select(expr[1].col).collect()[0][0]
1297
- pivot_literals.append(
1298
- expressions_proto.Expression.Literal(string=str(value))
2154
+
2155
+ if len(pivot_columns) != len(literals_proto):
2156
+ raise AnalysisException(
2157
+ f"[PIVOT_VALUE_DATA_TYPE_MISMATCH] Number of pivot columns ({len(pivot_columns)}) does not match number of values ({len(literals_proto)})"
2158
+ )
2159
+
2160
+ current_pivot_value_proto = (
2161
+ snowflake_proto.Aggregate.Pivot.PivotValue(
2162
+ values=literals_proto, alias=alias
2163
+ )
2164
+ if expr_proto.HasField("alias")
2165
+ else snowflake_proto.Aggregate.Pivot.PivotValue(
2166
+ values=literals_proto
2167
+ )
1299
2168
  )
1300
2169
 
2170
+ pivot_values.append(current_pivot_value_proto)
2171
+
1301
2172
  aggregate_expressions = [
1302
2173
  map_logical_plan_expression(e) for e in as_java_list(rel.aggregates())
1303
2174
  ]
1304
2175
 
1305
- proto = relation_proto.Relation(
1306
- aggregate=relation_proto.Aggregate(
1307
- input=map_logical_plan_relation(rel.child()),
1308
- aggregate_expressions=aggregate_expressions,
1309
- group_type=relation_proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
1310
- pivot=relation_proto.Aggregate.Pivot(
1311
- col=pivot_column, values=pivot_literals
1312
- ),
2176
+ any_proto = Any()
2177
+ any_proto.Pack(
2178
+ snowflake_proto.Extension(
2179
+ aggregate=snowflake_proto.Aggregate(
2180
+ input=map_logical_plan_relation(rel.child()),
2181
+ group_type=relation_proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
2182
+ aggregate_expressions=aggregate_expressions,
2183
+ having_condition=_having_condition.get(),
2184
+ pivot=snowflake_proto.Aggregate.Pivot(
2185
+ pivot_columns=pivot_columns,
2186
+ pivot_values=pivot_values,
2187
+ ),
2188
+ )
1313
2189
  )
1314
2190
  )
1315
-
2191
+ proto = relation_proto.Relation(extension=any_proto)
1316
2192
  case "PlanWithUnresolvedIdentifier":
1317
2193
  expr_proto = map_logical_plan_expression(rel.identifierExpr())
1318
2194
  session = snowpark.Session.get_active_session()
@@ -1343,23 +2219,119 @@ def map_logical_plan_relation(
1343
2219
  )
1344
2220
  )
1345
2221
  case "Sort":
2222
+ # Process the input first
2223
+ input_proto = map_logical_plan_relation(rel.child())
2224
+
2225
+ # Check if child is a Project - if so, build an alias map for ORDER BY resolution
2226
+ # This handles: SELECT o.date AS order_date ... ORDER BY o.date
2227
+ child_class = str(rel.child().getClass().getSimpleName())
2228
+ alias_map = {}
2229
+
2230
+ if child_class == "Project":
2231
+ # Extract aliases from SELECT clause
2232
+ for proj_expr in list(as_java_list(rel.child().projectList())):
2233
+ if str(proj_expr.getClass().getSimpleName()) == "Alias":
2234
+ alias_name = str(proj_expr.name())
2235
+ child_expr = proj_expr.child()
2236
+
2237
+ # Store mapping from original expression to alias name
2238
+ # Use string representation for matching
2239
+ expr_str = str(child_expr)
2240
+ alias_map[expr_str] = alias_name
2241
+
2242
+ # Also handle UnresolvedAttribute specifically to get the qualified name
2243
+ if (
2244
+ str(child_expr.getClass().getSimpleName())
2245
+ == "UnresolvedAttribute"
2246
+ ):
2247
+ # Get the qualified name like "o.date"
2248
+ name_parts = list(as_java_list(child_expr.nameParts()))
2249
+ qualified_name = ".".join(str(part) for part in name_parts)
2250
+ if qualified_name not in alias_map:
2251
+ alias_map[qualified_name] = alias_name
2252
+
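    # For the example above (SELECT o.date AS order_date ... ORDER BY o.date), the map
    # ends up keying both the raw Catalyst string of the aliased expression and the
    # qualified name on the alias, roughly: {"'o.date": "order_date", "o.date": "order_date"}.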
2253
+ # Process ORDER BY expressions, substituting aliases where needed
2254
+ order_list = []
2255
+ for order_expr in as_java_list(rel.order()):
2256
+ # Get the child expression from the SortOrder
2257
+ child_expr = order_expr.child()
2258
+ expr_class = str(child_expr.getClass().getSimpleName())
2259
+
2260
+ # Check if this expression matches any aliased expression
2261
+ expr_str = str(child_expr)
2262
+ substituted = False
2263
+
2264
+ if expr_str in alias_map:
2265
+ # Found a match - substitute with alias reference
2266
+ alias_name = alias_map[expr_str]
2267
+ # Create new UnresolvedAttribute for the alias
2268
+ UnresolvedAttribute = jpype.JClass(
2269
+ "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
2270
+ )
2271
+ new_attr = UnresolvedAttribute.quoted(alias_name)
2272
+
2273
+ # Create new SortOrder with substituted expression
2274
+ SortOrder = jpype.JClass(
2275
+ "org.apache.spark.sql.catalyst.expressions.SortOrder"
2276
+ )
2277
+ new_order = SortOrder(
2278
+ new_attr,
2279
+ order_expr.direction(),
2280
+ order_expr.nullOrdering(),
2281
+ order_expr.sameOrderExpressions(),
2282
+ )
2283
+ order_list.append(map_logical_plan_expression(new_order).sort_order)
2284
+ substituted = True
2285
+ elif expr_class == "UnresolvedAttribute":
2286
+ # Try matching on qualified name
2287
+ name_parts = list(as_java_list(child_expr.nameParts()))
2288
+ qualified_name = ".".join(str(part) for part in name_parts)
2289
+ if qualified_name in alias_map:
2290
+ alias_name = alias_map[qualified_name]
2291
+ UnresolvedAttribute = jpype.JClass(
2292
+ "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
2293
+ )
2294
+ new_attr = UnresolvedAttribute.quoted(alias_name)
2295
+
2296
+ SortOrder = jpype.JClass(
2297
+ "org.apache.spark.sql.catalyst.expressions.SortOrder"
2298
+ )
2299
+ new_order = SortOrder(
2300
+ new_attr,
2301
+ order_expr.direction(),
2302
+ order_expr.nullOrdering(),
2303
+ order_expr.sameOrderExpressions(),
2304
+ )
2305
+ order_list.append(
2306
+ map_logical_plan_expression(new_order).sort_order
2307
+ )
2308
+ substituted = True
2309
+
2310
+ if not substituted:
2311
+ # No substitution needed - use original
2312
+ order_list.append(
2313
+ map_logical_plan_expression(order_expr).sort_order
2314
+ )
2315
+
1346
2316
  proto = relation_proto.Relation(
1347
2317
  sort=relation_proto.Sort(
1348
- input=map_logical_plan_relation(rel.child()),
1349
- order=[
1350
- map_logical_plan_expression(e).sort_order
1351
- for e in as_java_list(rel.order())
1352
- ],
2318
+ input=input_proto,
2319
+ order=order_list,
1353
2320
  )
1354
2321
  )
1355
2322
  case "SubqueryAlias":
1356
2323
  alias = str(rel.alias())
1357
- proto = relation_proto.Relation(
1358
- subquery_alias=relation_proto.SubqueryAlias(
1359
- input=map_logical_plan_relation(rel.child()),
1360
- alias=alias,
1361
- )
2324
+ # If the child is an UnresolvedRelation, preserve the original plan id and register only the aliased one
2325
+ process_aliased_relation = (
2326
+ str(rel.child().getClass().getSimpleName()) == "UnresolvedRelation"
1362
2327
  )
2328
+ with push_processing_aliased_relation_scope(process_aliased_relation):
2329
+ proto = relation_proto.Relation(
2330
+ subquery_alias=relation_proto.SubqueryAlias(
2331
+ input=map_logical_plan_relation(rel.child()),
2332
+ alias=alias,
2333
+ )
2334
+ )
1363
2335
  set_sql_plan_name(alias, plan_id)
1364
2336
  case "Union":
1365
2337
  children = as_java_list(rel.children())
@@ -1381,12 +2353,14 @@ def map_logical_plan_relation(
1381
2353
 
1382
2354
  # Check for multi-column UNPIVOT which Snowflake doesn't support
1383
2355
  if len(value_column_names) > 1:
1384
- raise UnsupportedOperationException(
2356
+ exception = UnsupportedOperationException(
1385
2357
  f"Multi-column UNPIVOT is not supported. Snowflake SQL does not support unpivoting "
1386
2358
  f"multiple value columns ({', '.join(value_column_names)}) in a single operation. "
1387
2359
  f"Workaround: Use separate UNPIVOT operations for each value column and join the results, "
1388
2360
  f"or restructure your query to unpivot columns individually."
1389
2361
  )
2362
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
2363
+ raise exception
1390
2364
 
1391
2365
  values = []
1392
2366
  values_groups = as_java_list(rel.values().get())
@@ -1394,11 +2368,13 @@ def map_logical_plan_relation(
1394
2368
  # Check if we have multi-column groups in the IN clause
1395
2369
  if values_groups and len(as_java_list(values_groups[0])) > 1:
1396
2370
  group_sizes = [len(as_java_list(group)) for group in values_groups]
1397
- raise UnsupportedOperationException(
2371
+ exception = UnsupportedOperationException(
1398
2372
  f"Multi-column UNPIVOT is not supported. Snowflake SQL does not support unpivoting "
1399
2373
  f"multiple columns together in groups. Found groups with {max(group_sizes)} columns. "
1400
2374
  f"Workaround: Unpivot each column separately and then join/union the results as needed."
1401
2375
  )
2376
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
2377
+ raise exception
1402
2378
 
1403
2379
  for e1 in values_groups:
1404
2380
  for e in as_java_list(e1):
@@ -1444,9 +2420,11 @@ def map_logical_plan_relation(
1444
2420
  # Store the having condition in context and process the child aggregate
1445
2421
  child_relation = rel.child()
1446
2422
  if str(child_relation.getClass().getSimpleName()) != "Aggregate":
1447
- raise SnowparkConnectNotImplementedError(
2423
+ exception = SnowparkConnectNotImplementedError(
1448
2424
  "UnresolvedHaving can only be applied to Aggregate relations"
1449
2425
  )
2426
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
2427
+ raise exception
1450
2428
 
1451
2429
  # Store having condition in a context variable for the Aggregate case to pick up
1452
2430
  having_condition = map_logical_plan_expression(rel.havingCondition())
@@ -1509,7 +2487,8 @@ def map_logical_plan_relation(
1509
2487
  )
1510
2488
  case "UnresolvedRelation":
1511
2489
  name = str(rel.name())
1512
- set_sql_plan_name(name, plan_id)
2490
+ if not get_is_processing_aliased_relation():
2491
+ set_sql_plan_name(name, plan_id)
1513
2492
 
1514
2493
  cte_proto = _ctes.get().get(name)
1515
2494
  if cte_proto is not None:
@@ -1530,10 +2509,16 @@ def map_logical_plan_relation(
1530
2509
  )
1531
2510
 
1532
2511
  # Re-evaluate the CTE definition with a fresh plan_id
1533
- fresh_plan_id = gen_sql_plan_id()
1534
- fresh_cte_proto = map_logical_plan_relation(
1535
- cte_definition, fresh_plan_id
1536
- )
2512
+ # Clear HAVING condition to prevent leakage from outer CTEs
2513
+ saved_having = _having_condition.get()
2514
+ _having_condition.set(None)
2515
+ try:
2516
+ fresh_plan_id = gen_sql_plan_id()
2517
+ fresh_cte_proto = map_logical_plan_relation(
2518
+ cte_definition, fresh_plan_id
2519
+ )
2520
+ finally:
2521
+ _having_condition.set(saved_having)
1537
2522
 
1538
2523
  # Use SubqueryColumnAliases to ensure consistent column names across CTE references
1539
2524
  # This is crucial for CTEs that reference other CTEs
@@ -1612,14 +2597,35 @@ def map_logical_plan_relation(
1612
2597
  .collect()[0]
1613
2598
  )
1614
2599
 
2600
+ def _parse_value(argument, place):
2601
+ if isinstance(argument, (Decimal, float)):
2602
+ return int(argument)
2603
+ elif isinstance(argument, str):
2604
+ try:
2605
+ value = float(argument)
2606
+ if value < 0:
2607
+ return math.ceil(value)
2608
+ return math.floor(float(argument))
2609
+ except ValueError:
2610
+ raise AnalysisException(
2611
+ f'[UNEXPECTED_INPUT_TYPE] Parameter {place} of function `range` requires the "BIGINT" type, however "{argument}" has the type "STRING"'
2612
+ )
2613
+ return argument
2614
+
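    # Rough behavior of the coercion above, for illustration:
    #   _parse_value(Decimal("5"), 1)  -> 5
    #   _parse_value(7, 1)             -> 7      (ints pass through unchanged)
    #   _parse_value("2.7", 1)         -> 2      (positive strings are floored)
    #   _parse_value("-2.7", 1)        -> -2     (negative strings are ceiled, i.e. truncated toward zero)
    #   _parse_value("abc", 1)         -> raises AnalysisException ([UNEXPECTED_INPUT_TYPE])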
1615
2615
  start, step = 0, 1
1616
2616
  match args:
1617
2617
  case [_]:
1618
2618
  [end] = args
2619
+ end = _parse_value(end, 1)
1619
2620
  case [_, _]:
1620
2621
  [start, end] = args
2622
+ start = _parse_value(start, 1)
2623
+ end = _parse_value(end, 2)
1621
2624
  case [_, _, _]:
1622
2625
  [start, end, step] = args
2626
+ start = _parse_value(start, 1)
2627
+ end = _parse_value(end, 2)
2628
+ step = _parse_value(step, 3)
1623
2629
 
1624
2630
  proto = relation_proto.Relation(
1625
2631
  range=relation_proto.Range(
@@ -1688,16 +2694,7 @@ def map_logical_plan_relation(
1688
2694
  ),
1689
2695
  )
1690
2696
  case "UnresolvedWith":
1691
- with _push_cte_scope():
1692
- for cte in as_java_list(rel.cteRelations()):
1693
- name = str(cte._1())
1694
- # Store the original CTE definition for re-evaluation
1695
- _cte_definitions.get()[name] = cte._2()
1696
- # Process CTE definition with a unique plan_id to ensure proper column naming
1697
- cte_plan_id = gen_sql_plan_id()
1698
- cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
1699
- _ctes.get()[name] = cte_proto
1700
-
2697
+ with _with_cte_scope(rel.cteRelations()):
1701
2698
  proto = map_logical_plan_relation(rel.child())
1702
2699
  case "LateralJoin":
1703
2700
  left = map_logical_plan_relation(rel.left())
@@ -1719,41 +2716,16 @@ def map_logical_plan_relation(
1719
2716
  _window_specs.get()[key] = window_spec
1720
2717
  proto = map_logical_plan_relation(rel.child())
1721
2718
  case "Generate":
1722
- # Generate creates a nested Project relation (see lines 1785-1790) without
1723
- # setting its plan_id field. When this Project is later processed by map_project
1724
- # (map_column_ops.py), it uses rel.common.plan_id which defaults to 0 for unset
1725
- # protobuf fields. This means all columns from the Generate operation (both exploded
1726
- # columns and passthrough columns) will have plan_id=0 in their names.
1727
- #
1728
- # If Generate's child is a SubqueryAlias whose inner relation was processed
1729
- # with a non-zero plan_id, there will be a mismatch between:
1730
- # - The columns referenced in the Project (expecting plan_id from SubqueryAlias's child)
1731
- # - The actual column names created by Generate's Project (using plan_id=0)
1732
-
1733
- # Therefore, when Generate has a SubqueryAlias child, we explicitly process the inner
1734
- # relation with plan_id=0 to match what Generate's Project will use. This only applies when
1735
- # the immediate child of Generate is a SubqueryAlias and preserves existing registrations (like CTEs),
1736
- # so it won't affect other patterns.
1737
-
1738
2719
  child_class = str(rel.child().getClass().getSimpleName())
1739
2720
 
1740
2721
  if child_class == "SubqueryAlias":
1741
2722
  alias = str(rel.child().alias())
1742
2723
 
1743
- # Check if this alias was already registered during initial SQL parsing
1744
2724
  existing_plan_id = get_sql_plan(alias)
1745
2725
 
1746
- if existing_plan_id is not None:
1747
- # Use the existing plan_id to maintain consistency with prior registration
1748
- used_plan_id = existing_plan_id
1749
- else:
1750
- # Use plan_id=0 to match what the nested Project will use (protobuf default)
1751
- used_plan_id = 0
1752
- set_sql_plan_name(alias, used_plan_id)
1753
-
1754
2726
  # Process the inner child with the determined plan_id
1755
2727
  inner_child = map_logical_plan_relation(
1756
- rel.child().child(), plan_id=used_plan_id
2728
+ rel.child().child(), plan_id=existing_plan_id
1757
2729
  )
1758
2730
  input_relation = relation_proto.Relation(
1759
2731
  subquery_alias=relation_proto.SubqueryAlias(
@@ -1771,19 +2743,19 @@ def map_logical_plan_relation(
1771
2743
  function_name = rel.generator().name().toString()
1772
2744
  func_arguments = [
1773
2745
  map_logical_plan_expression(e)
1774
- for e in as_java_list(rel.generator().children())
2746
+ for e in list(as_java_list(rel.generator().children()))
1775
2747
  ]
1776
2748
  unresolved_fun_proto = expressions_proto.Expression.UnresolvedFunction(
1777
2749
  function_name=function_name, arguments=func_arguments
1778
2750
  )
1779
2751
 
1780
- aliased_proto = unresolved_fun_proto
2752
+ aliased_proto = expressions_proto.Expression(
2753
+ unresolved_function=unresolved_fun_proto,
2754
+ )
1781
2755
  if generator_output_list.size() > 0:
1782
2756
  aliased_proto = expressions_proto.Expression(
1783
2757
  alias=expressions_proto.Expression.Alias(
1784
- expr=expressions_proto.Expression(
1785
- unresolved_function=unresolved_fun_proto,
1786
- ),
2758
+ expr=aliased_proto,
1787
2759
  name=[attribute.name() for attribute in generator_output_list],
1788
2760
  )
1789
2761
  )
@@ -1837,28 +2809,67 @@ def map_logical_plan_relation(
1837
2809
  )
1838
2810
  proto = generator_dataframe_proto
1839
2811
  case other:
1840
- raise SnowparkConnectNotImplementedError(f"Unimplemented relation: {other}")
2812
+ exception = SnowparkConnectNotImplementedError(
2813
+ f"Unimplemented relation: {other}"
2814
+ )
2815
+ attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
2816
+ raise exception
1841
2817
 
1842
2818
  proto.common.plan_id = plan_id
1843
2819
 
1844
2820
  return proto
1845
2821
 
1846
2822
 
2823
+ def _get_relation_identifier(name_obj) -> str:
2824
+ # IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
2825
+ expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
2826
+ session = snowpark.Session.get_active_session()
2827
+ m = ColumnNameMap([], [], None)
2828
+ expr = map_single_column_expression(
2829
+ expr_proto, m, ExpressionTyper.dummy_typer(session)
2830
+ )
2831
+ return spark_to_sf_single_id(session.range(1).select(expr[1].col).collect()[0][0])
2832
+
2833
+
2834
+ def _create_temp_view_name(parts) -> str:
2835
+ return ".".join(
2836
+ quote_name_without_upper_casing(str(part)) for part in as_java_list(parts)
2837
+ )
2838
+
2839
+
2840
+ def get_relation_identifier_name_without_uppercasing(
2841
+ name_obj, is_multi_part: bool = False
2842
+ ) -> str:
2843
+ if name_obj.getClass().getSimpleName() in (
2844
+ "PlanWithUnresolvedIdentifier",
2845
+ "ExpressionWithUnresolvedIdentifier",
2846
+ ):
2847
+ return _get_relation_identifier(name_obj)
2848
+ elif is_multi_part:
2849
+ try:
2850
+ # Try multipartIdentifier first for full catalog.database.table
2851
+ return _create_temp_view_name(name_obj.multipartIdentifier())
2852
+ except AttributeError:
2853
+ # Fallback to nameParts if multipartIdentifier not available
2854
+ return _create_temp_view_name(name_obj.nameParts())
2855
+ else:
2856
+ return _create_temp_view_name(name_obj.nameParts())
2857
+
2858
+
1847
2859
  def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
1848
- if name_obj.getClass().getSimpleName() == "PlanWithUnresolvedIdentifier":
1849
- # IDENTIFIER(<table_name>)
1850
- expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
1851
- session = snowpark.Session.get_active_session()
1852
- m = ColumnNameMap([], [], None)
1853
- expr = map_single_column_expression(
1854
- expr_proto, m, ExpressionTyper.dummy_typer(session)
1855
- )
1856
- name = spark_to_sf_single_id(
1857
- session.range(1).select(expr[1].col).collect()[0][0]
1858
- )
2860
+ if name_obj.getClass().getSimpleName() in (
2861
+ "PlanWithUnresolvedIdentifier",
2862
+ "ExpressionWithUnresolvedIdentifier",
2863
+ ):
2864
+ return _get_relation_identifier(name_obj)
1859
2865
  else:
1860
2866
  if is_multi_part:
1861
- name = _spark_to_snowflake(name_obj.multipartIdentifier())
2867
+ try:
2868
+ # Try multipartIdentifier first for full catalog.database.table
2869
+ name = _spark_to_snowflake(name_obj.multipartIdentifier())
2870
+ except AttributeError:
2871
+ # Fallback to nameParts if multipartIdentifier not available
2872
+ name = _spark_to_snowflake(name_obj.nameParts())
1862
2873
  else:
1863
2874
  name = _spark_to_snowflake(name_obj.nameParts())
1864
2875