snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +680 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +237 -23
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  23. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  24. snowflake/snowpark_connect/expression/literal.py +37 -13
  25. snowflake/snowpark_connect/expression/map_cast.py +123 -5
  26. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  27. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  28. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  29. snowflake/snowpark_connect/expression/map_udf.py +85 -20
  30. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  31. snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
  32. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  33. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  34. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  35. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  36. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  37. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  38. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  39. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  40. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  41. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  42. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  43. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  44. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  45. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  46. snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
  47. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  48. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  49. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  50. snowflake/snowpark_connect/relation/map_join.py +683 -442
  51. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  52. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  53. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  54. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  55. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  56. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  57. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  58. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  59. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  60. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  61. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  62. snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
  63. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
  64. snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
  65. snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
  66. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  67. snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
  68. snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
  69. snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
  70. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  71. snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
  72. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  73. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  74. snowflake/snowpark_connect/relation/utils.py +128 -5
  75. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  76. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  77. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  78. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  79. snowflake/snowpark_connect/resources_initializer.py +110 -48
  80. snowflake/snowpark_connect/server.py +546 -456
  81. snowflake/snowpark_connect/server_common/__init__.py +500 -0
  82. snowflake/snowpark_connect/snowflake_session.py +65 -0
  83. snowflake/snowpark_connect/start_server.py +53 -5
  84. snowflake/snowpark_connect/type_mapping.py +349 -27
  85. snowflake/snowpark_connect/typed_column.py +9 -7
  86. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  87. snowflake/snowpark_connect/utils/cache.py +49 -27
  88. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  89. snowflake/snowpark_connect/utils/context.py +187 -37
  90. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  91. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  92. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  93. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  94. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  95. snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
  96. snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
  97. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  98. snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
  99. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  100. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  101. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  102. snowflake/snowpark_connect/utils/profiling.py +25 -8
  103. snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
  104. snowflake/snowpark_connect/utils/sequence.py +21 -0
  105. snowflake/snowpark_connect/utils/session.py +64 -28
  106. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  107. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  108. snowflake/snowpark_connect/utils/telemetry.py +163 -22
  109. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  110. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  111. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  112. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  113. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  114. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  115. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  116. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  117. snowflake/snowpark_connect/version.py +1 -1
  118. snowflake/snowpark_decoder/dp_session.py +6 -2
  119. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  120. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
  121. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
  122. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
  123. snowflake/snowpark_connect/hidden_column.py +0 -39
  124. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  125. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  126. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  127. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  128. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  129. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  130. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  131. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  132. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  133. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  134. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  186. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
  187. {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
  188. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
  189. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
  190. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  191. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
  192. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py
@@ -2,9 +2,8 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
-import re
+import copy
 from dataclasses import dataclass
-from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
@@ -16,18 +15,26 @@ from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
 )
-from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
+from snowflake.snowpark_connect.dataframe_container import (
+    AggregateMetadata,
+    DataFrameContainer,
+)
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.utils import (
+    create_pivot_column_condition,
+    map_pivot_value_to_spark_column_name,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils import expression_transformer
 from snowflake.snowpark_connect.utils.context import (
-    get_is_evaluating_sql,
+    grouping_by_scala_udf_key,
     set_current_grouping_columns,
-    temporary_pivot_expression,
 )
 
 
@@ -49,6 +56,20 @@ def map_group_by_aggregate(
     result = input_df_actual.group_by(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
     )
+
+    # Store aggregate metadata for ORDER BY resolution
+    aggregate_metadata = AggregateMetadata(
+        input_column_map=input_df_container.column_map,
+        input_dataframe=input_df_actual,
+        grouping_expressions=list(rel.aggregate.grouping_expressions),
+        aggregate_expressions=list(rel.aggregate.aggregate_expressions),
+        spark_columns=columns.spark_names(),
+        raw_aggregations=[
+            (col.spark_name, TypedColumn(col.expression, col.data_type))
+            for col in columns.aggregation_columns
+        ],
+    )
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -56,6 +77,8 @@ def map_group_by_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_df_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
+        aggregate_metadata=aggregate_metadata,
     )
 
 
@@ -77,6 +100,11 @@ def map_rollup_aggregate(
     result = input_df_actual.rollup(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
     )
+
+    # NOTE: Do NOT attach aggregate_metadata for ROLLUP
+    # Spark does not allow ORDER BY to reference pre-aggregation columns for ROLLUP
+    # Only regular GROUP BY supports this
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -84,6 +112,7 @@ def map_rollup_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
     )
 
 
@@ -105,6 +134,11 @@ def map_cube_aggregate(
     result = input_df_actual.cube(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
     )
+
+    # NOTE: Do NOT attach aggregate_metadata for CUBE
+    # Spark does not allow ORDER BY to reference pre-aggregation columns for CUBE
+    # Only regular GROUP BY supports this
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -112,6 +146,7 @@ def map_cube_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
    )
 
 
@@ -136,220 +171,111 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]
 
-    used_columns = {pivot_column[1].col._expression.name}
-    if get_is_evaluating_sql():
-        # When evaluating SQL spark doesn't trim columns from the result
-        used_columns = {"*"}
-    else:
-        for expression in rel.aggregate.aggregate_expressions:
-            matched_identifiers = re.findall(
-                r'unparsed_identifier: "(.*)"', expression.__str__()
-            )
-            for identifier in matched_identifiers:
-                mapped_col = input_container.column_map.spark_to_col.get(
-                    identifier, None
-                )
-                if mapped_col:
-                    used_columns.add(mapped_col[0].snowpark_name)
+    if not pivot_values:
+        distinct_col_values = (
+            input_df_actual.select(pivot_column[1].col)
+            .distinct()
+            .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+            .collect()
+        )
+        pivot_values = [
+            row[0].as_dict() if isinstance(row[0], snowpark.Row) else row[0]
+            for row in distinct_col_values
+        ]
 
-    if len(columns.grouping_expressions()) == 0:
-        # Snowpark doesn't support multiple aggregations in pivot without groupBy
-        # So we need to perform each aggregation separately and then combine results
-        if len(columns.aggregation_expressions(unalias=True)) > 1:
-            agg_expressions = columns.aggregation_expressions(unalias=True)
-            agg_metadata = columns.aggregation_columns
-            num_agg_functions = len(agg_expressions)
-
-            spark_names = []
-            pivot_results = []
-            for i, agg_expr in enumerate(agg_expressions):
-                pivot_result = (
-                    input_df_actual.select(*used_columns)
-                    .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                    .agg(agg_expr)
+    agg_expressions = columns.aggregation_expressions(unalias=True)
+
+    spark_col_names = []
+    aggregations = []
+    final_pivot_names = []
+    grouping_columns_qualifiers = []
+    grouping_eq_snowpark_names = []
+
+    grouping_columns = columns.grouping_expressions()
+    if grouping_columns:
+        for col in grouping_columns:
+            snowpark_name = col.get_name()
+            spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                snowpark_name
+            )
+            qualifiers = input_container.column_map.get_qualifiers_for_snowpark_column(
+                snowpark_name
+            )
+            grouping_columns_qualifiers.append(qualifiers)
+            spark_col_names.append(spark_col_name)
+            grouping_eq_snowpark_names.append(
+                input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                    snowpark_name
                 )
-                for col_name in pivot_result.columns:
-                    spark_names.append(
-                        f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-                    )
-                pivot_results.append(pivot_result)
-
-            result = pivot_results[0]
-            for pivot_result in pivot_results[1:]:
-                result = result.cross_join(pivot_result)
-
-            pivot_columns_per_agg = len(pivot_results[0].columns)
-            reordered_spark_names = []
-            reordered_snowpark_names = []
-            reordered_types = []
-            column_selectors = []
-
-            for pivot_idx in range(pivot_columns_per_agg):
-                for agg_idx in range(num_agg_functions):
-                    current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-                    if current_pos < len(spark_names):
-                        idx = current_pos + 1  # 1-based indexing for Snowpark
-                        reordered_spark_names.append(spark_names[current_pos])
-                        reordered_snowpark_names.append(f"${idx}")
-                        reordered_types.append(
-                            result.schema.fields[current_pos].datatype
-                        )
-                        column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-            return DataFrameContainer.create_with_column_mapping(
-                dataframe=result.select(*column_selectors),
-                spark_column_names=reordered_spark_names,
-                snowpark_column_names=reordered_snowpark_names,
-                column_qualifiers=[[]] * len(reordered_spark_names),
-                parent_column_name_map=input_container.column_map,
-                snowpark_column_types=reordered_types,
             )
-        else:
-            result = (
-                input_df_actual.select(*used_columns)
-                .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                .agg(*columns.aggregation_expressions(unalias=True))
+
+    for pv_value in pivot_values:
+        pv_value_spark, pv_is_null = map_pivot_value_to_spark_column_name(pv_value)
+
+        for i, agg_expression in enumerate(agg_expressions):
+            agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+            condition = create_pivot_column_condition(
+                pivot_column[1].col,
+                pv_value,
+                pv_is_null,
+                pivot_column[1].typ if isinstance(pv_value, (list, dict)) else None,
             )
-    else:
-        result = (
-            input_df_actual.group_by(*columns.grouping_expressions())
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions(unalias=True))
-        )
 
-    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+            expression_transformer.inject_condition_to_all_agg_functions(
+                agg_fun_expr, condition
+            )
 
-    # Calculate number of pivot values for proper Spark-compatible indexing
-    total_pivot_columns = len(result.columns) - len(agg_name_list)
-    num_pivot_values = (
-        total_pivot_columns // len(columns.aggregation_columns)
-        if len(columns.aggregation_columns) > 0
-        else 1
-    )
+            curr_expression = Column(agg_fun_expr)
 
-    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
-        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
-            return None
-        else:
-            index = (col_index - len(agg_name_list)) // num_pivot_values
-            return columns.aggregation_columns[index].spark_name
-
-    spark_columns = []
-    for col in [
-        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-        for i, c in enumerate(result.columns)
-    ]:
-        spark_col = (
-            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                col, allow_non_exists=True
+            spark_col_name = (
+                f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                if len(agg_expressions) > 1
+                else f"{pv_value_spark}"
             )
-        )
 
-        if spark_col is not None:
-            spark_columns.append(spark_col)
-        else:
-            # Handle NULL column names to match Spark behavior (lowercase 'null')
-            if col == "NULL":
-                spark_columns.append(col.lower())
-            else:
-                spark_columns.append(col)
-
-    grouping_cols_count = len(agg_name_list)
-    pivot_cols = result.columns[grouping_cols_count:]
-    spark_pivot_cols = spark_columns[grouping_cols_count:]
-
-    num_agg_functions = len(columns.aggregation_columns)
-    num_pivot_values = len(pivot_cols) // num_agg_functions
-
-    reordered_snowpark_cols = []
-    reordered_spark_cols = []
-    column_indices = []  # 1-based indexing
-
-    for i in range(grouping_cols_count):
-        reordered_snowpark_cols.append(result.columns[i])
-        reordered_spark_cols.append(spark_columns[i])
-        column_indices.append(i + 1)
-
-    for pivot_idx in range(num_pivot_values):
-        for agg_idx in range(num_agg_functions):
-            current_pos = agg_idx * num_pivot_values + pivot_idx
-            if current_pos < len(pivot_cols):
-                reordered_snowpark_cols.append(pivot_cols[current_pos])
-                reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                original_index = grouping_cols_count + current_pos
-                column_indices.append(original_index + 1)
-
-    reordered_result = result.select(
-        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+            snowpark_col_name = make_column_names_snowpark_compatible(
+                [spark_col_name],
+                rel.common.plan_id,
+                len(grouping_columns) + len(agg_expressions),
+            )[0]
+
+            curr_expression = curr_expression.alias(snowpark_col_name)
+
+            aggregations.append(curr_expression)
+            spark_col_names.append(spark_col_name)
+            final_pivot_names.append(snowpark_col_name)
+
+    result_df = (
+        input_df_actual.group_by(*grouping_columns)
+        .agg(*aggregations)
+        .select(*grouping_columns, *final_pivot_names)
    )
 
     return DataFrameContainer.create_with_column_mapping(
-        dataframe=reordered_result,
-        spark_column_names=reordered_spark_cols,
-        snowpark_column_names=[f"${idx}" for idx in column_indices],
-        column_qualifiers=(
-            columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-        ),
-        parent_column_name_map=input_container.column_map,
+        dataframe=result_df,
+        spark_column_names=spark_col_names,
+        snowpark_column_names=result_df.columns,
         snowpark_column_types=[
-            result.schema.fields[idx - 1].datatype for idx in column_indices
+            result_df.schema.fields[idx].datatype
+            for idx, _ in enumerate(result_df.columns)
         ],
+        column_qualifiers=grouping_columns_qualifiers
+        + [set() for _ in final_pivot_names],
+        parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=grouping_eq_snowpark_names
+        + [set() for _ in final_pivot_names],
     )
 
 
-def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-    # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-    # 1. "'Java'" -> Java
-    # 2. "'""C++""'" -> "C++"
-    # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-    # As we can see:
-    # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-    # 2. the string content is nested in single quotes ('<string_content>')
-    # 3. double quote is escased by another double quote, this is snowflake behavior
-    # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-    try:
-        # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-        # extract the content between the outermost double quote followed by a single quote "'
-        content = match.group(1)
-        # convert the escaped double quote to the actual double quote
-        content = content.replace('""', '"')
-        escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-        # replace two consecutive single quote in the content with a placeholder, the first single quote needs to be preserved
-        content = re.sub(r"''", escape_single_quote_placeholder, content)
-        # remove the solo single quote, they are not part of the string content
-        content = re.sub(r"'", "", content)
-        # replace the placeholder with the single quote which we want to preserve
-        result = content.replace(escape_single_quote_placeholder, "'")
-        return f"{result}_{opt_alias}" if opt_alias else result
-    except Exception:
-        # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-        spark_string = ""
-        for entry in list(filter(None, double_quote_list)):
-            if "'" in entry:
-                entry = entry.replace("'", "")
-                if len(entry) > 0:
-                    spark_string += entry
-            elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                # skip quoting digits or decimal numbers as column names.
-                spark_string += entry
-            else:
-                spark_string += '"' + entry + '"'
-        return snowpark_cname if spark_string == "" else spark_string
-
-
 @dataclass(frozen=True)
 class _ColumnMetadata:
     expression: snowpark.Column
     spark_name: str
     snowpark_name: str
     data_type: DataType
-    qualifiers: list[str]
+    qualifiers: set[ColumnQualifier]
+    equivalent_snowpark_names: set[str]
 
 
 @dataclass(frozen=True)
@@ -385,7 +311,7 @@ class _Columns:
             col.spark_name for col in self.grouping_columns + self.aggregation_columns
         ]
 
-    def get_qualifiers(self) -> list[list[str]]:
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         return [
             col.qualifiers for col in self.grouping_columns + self.aggregation_columns
         ]
@@ -399,6 +325,12 @@ class _Columns:
             if col.data_type is not None
         ]
 
+    def get_equivalent_snowpark_names(self) -> list[set[str]]:
+        return [
+            col.equivalent_snowpark_names
+            for col in self.grouping_columns + self.aggregation_columns
+        ]
+
 
 def map_aggregate_helper(
     rel: relation_proto.Relation, pivot: bool = False, skip_alias: bool = False
@@ -413,71 +345,80 @@ def map_aggregate_helper(
     typer = ExpressionTyper(input_df)
     schema_inferrable = True
 
-    with temporary_pivot_expression(pivot):
-        for exp in grouping_expressions:
+    for exp in grouping_expressions:
+        with grouping_by_scala_udf_key(
+            exp.WhichOneof("expr_type") == "common_inline_user_defined_function"
+            and exp.common_inline_user_defined_function.scalar_scala_udf is not None
+        ):
             new_name, snowpark_column = map_single_column_expression(
                 exp, input_container.column_map, typer
             )
-            alias = make_column_names_snowpark_compatible(
-                [new_name], rel.common.plan_id, len(groupings)
-            )[0]
-            groupings.append(
-                _ColumnMetadata(
-                    snowpark_column.col
-                    if skip_alias
-                    else snowpark_column.col.alias(alias),
-                    new_name,
-                    None if skip_alias else alias,
-                    None if pivot else snowpark_column.typ,
-                    snowpark_column.get_qualifiers(),
-                )
-            )
 
-        grouping_cols = [g.spark_name for g in groupings]
-        set_current_grouping_columns(grouping_cols)
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings)
+        )[0]
 
-        for exp in expressions:
-            new_name, snowpark_column = map_single_column_expression(
-                exp, input_container.column_map, typer
+        equivalent_snowpark_names = (
+            input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                snowpark_column.col.get_name()
             )
-            alias = make_column_names_snowpark_compatible(
-                [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-            )[0]
+        )
 
-            def type_agg_expr(
-                agg_exp: TypedColumn, schema_inferrable: bool
-            ) -> DataType | None:
-                if pivot or not schema_inferrable:
-                    return None
-                try:
-                    return agg_exp.typ
-                except Exception:
-                    # This type used for schema inference optimization purposes.
-                    # typer may not be able to infer the type of some expressions
-                    # in that case we return None, and the optimization will not be applied.
-                    return None
-
-            agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-            if agg_col_typ is None:
-                schema_inferrable = False
-
-            aggregations.append(
-                _ColumnMetadata(
-                    snowpark_column.col
-                    if skip_alias
-                    else snowpark_column.col.alias(alias),
-                    new_name,
-                    None if skip_alias else alias,
-                    agg_col_typ,
-                    [],
-                )
+        groupings.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                None if pivot else snowpark_column.typ,
+                qualifiers=snowpark_column.get_qualifiers(),
+                equivalent_snowpark_names=equivalent_snowpark_names,
             )
+        )
+
+    grouping_cols = [g.spark_name for g in groupings]
+    set_current_grouping_columns(grouping_cols)
 
-    return (
-        input_container,
-        _Columns(
-            grouping_columns=groupings,
-            aggregation_columns=aggregations,
-            can_infer_schema=schema_inferrable,
-        ),
+    for exp in expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+        )[0]
+
+        def type_agg_expr(
+            agg_exp: TypedColumn, schema_inferrable: bool
+        ) -> DataType | None:
+            if pivot or not schema_inferrable:
+                return None
+            try:
+                return agg_exp.typ
+            except Exception:
+                # This type used for schema inference optimization purposes.
+                # typer may not be able to infer the type of some expressions
+                # in that case we return None, and the optimization will not be applied.
+                return None
+
+        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+        if agg_col_typ is None:
+            schema_inferrable = False
+
+        aggregations.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                agg_col_typ,
+                qualifiers=set(),
+                equivalent_snowpark_names=set(),
+            )
         )
+
+    return (
+        input_container,
+        _Columns(
+            grouping_columns=groupings,
+            aggregation_columns=aggregations,
+            can_infer_schema=schema_inferrable,
+        ),
+    )
snowflake/snowpark_connect/relation/map_catalog.py
@@ -8,6 +8,8 @@ import pandas
 import pyspark.sql.connect.proto.catalog_pb2 as catalog_proto
 
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.relation.catalogs import CATALOGS
 from snowflake.snowpark_connect.relation.catalogs.utils import (
     CURRENT_CATALOG_NAME,
@@ -148,4 +150,6 @@ def map_catalog(
             return get_current_catalog().uncacheTable(rel.uncache_table.table_name)
         case other:
             # TODO: list_function implementation is blocked on SNOW-1787268
-            raise SnowparkConnectNotImplementedError(f"Other Relation {other}")
+            exception = SnowparkConnectNotImplementedError(f"Other Relation {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception