snowpark-connect 0.27.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  3. snowflake/snowpark_connect/client/__init__.py +15 -0
  4. snowflake/snowpark_connect/client/error_utils.py +30 -0
  5. snowflake/snowpark_connect/client/exceptions.py +36 -0
  6. snowflake/snowpark_connect/client/query_results.py +90 -0
  7. snowflake/snowpark_connect/client/server.py +717 -0
  8. snowflake/snowpark_connect/client/utils/__init__.py +10 -0
  9. snowflake/snowpark_connect/client/utils/session.py +85 -0
  10. snowflake/snowpark_connect/column_name_handler.py +404 -243
  11. snowflake/snowpark_connect/column_qualifier.py +43 -0
  12. snowflake/snowpark_connect/config.py +309 -26
  13. snowflake/snowpark_connect/constants.py +2 -0
  14. snowflake/snowpark_connect/dataframe_container.py +102 -8
  15. snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
  16. snowflake/snowpark_connect/error/error_codes.py +50 -0
  17. snowflake/snowpark_connect/error/error_utils.py +172 -23
  18. snowflake/snowpark_connect/error/exceptions.py +13 -4
  19. snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
  20. snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
  21. snowflake/snowpark_connect/execute_plan/utils.py +5 -1
  22. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  23. snowflake/snowpark_connect/expression/function_defaults.py +9 -2
  24. snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
  25. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  26. snowflake/snowpark_connect/expression/literal.py +37 -13
  27. snowflake/snowpark_connect/expression/map_cast.py +224 -15
  28. snowflake/snowpark_connect/expression/map_expression.py +80 -27
  29. snowflake/snowpark_connect/expression/map_extension.py +322 -12
  30. snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
  31. snowflake/snowpark_connect/expression/map_udf.py +86 -20
  32. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
  33. snowflake/snowpark_connect/expression/map_unresolved_function.py +2964 -829
  34. snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
  35. snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
  36. snowflake/snowpark_connect/expression/map_window_function.py +18 -3
  37. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  38. snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
  39. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/{spark-connect-client-jvm_2.12-3.5.6.jar → spark-connect-client-jvm_2.13-3.5.6.jar} +0 -0
  43. snowflake/snowpark_connect/includes/jars/{spark-sql_2.12-3.5.6.jar → spark-sql_2.13-3.5.6.jar} +0 -0
  44. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
  45. snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
  46. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  47. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  48. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
  49. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
  50. snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
  51. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
  52. snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
  53. snowflake/snowpark_connect/relation/io_utils.py +110 -10
  54. snowflake/snowpark_connect/relation/map_aggregate.py +239 -256
  55. snowflake/snowpark_connect/relation/map_catalog.py +5 -1
  56. snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
  57. snowflake/snowpark_connect/relation/map_extension.py +263 -29
  58. snowflake/snowpark_connect/relation/map_join.py +683 -442
  59. snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
  60. snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
  61. snowflake/snowpark_connect/relation/map_relation.py +48 -19
  62. snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
  63. snowflake/snowpark_connect/relation/map_show_string.py +13 -6
  64. snowflake/snowpark_connect/relation/map_sql.py +1233 -222
  65. snowflake/snowpark_connect/relation/map_stats.py +48 -9
  66. snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
  67. snowflake/snowpark_connect/relation/map_udtf.py +14 -4
  68. snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
  69. snowflake/snowpark_connect/relation/read/map_read.py +134 -43
  70. snowflake/snowpark_connect/relation/read/map_read_csv.py +326 -47
  71. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +21 -6
  72. snowflake/snowpark_connect/relation/read/map_read_json.py +324 -86
  73. snowflake/snowpark_connect/relation/read/map_read_parquet.py +146 -28
  74. snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
  75. snowflake/snowpark_connect/relation/read/map_read_socket.py +15 -3
  76. snowflake/snowpark_connect/relation/read/map_read_table.py +86 -6
  77. snowflake/snowpark_connect/relation/read/map_read_text.py +22 -4
  78. snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
  79. snowflake/snowpark_connect/relation/read/reader_config.py +42 -3
  80. snowflake/snowpark_connect/relation/read/utils.py +50 -5
  81. snowflake/snowpark_connect/relation/stage_locator.py +91 -55
  82. snowflake/snowpark_connect/relation/utils.py +128 -5
  83. snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
  84. snowflake/snowpark_connect/relation/write/map_write.py +929 -319
  85. snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
  86. snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
  87. snowflake/snowpark_connect/resources_initializer.py +171 -48
  88. snowflake/snowpark_connect/server.py +528 -473
  89. snowflake/snowpark_connect/server_common/__init__.py +503 -0
  90. snowflake/snowpark_connect/snowflake_session.py +65 -0
  91. snowflake/snowpark_connect/start_server.py +53 -5
  92. snowflake/snowpark_connect/type_mapping.py +349 -27
  93. snowflake/snowpark_connect/type_support.py +130 -0
  94. snowflake/snowpark_connect/typed_column.py +9 -7
  95. snowflake/snowpark_connect/utils/artifacts.py +9 -8
  96. snowflake/snowpark_connect/utils/cache.py +49 -27
  97. snowflake/snowpark_connect/utils/concurrent.py +36 -1
  98. snowflake/snowpark_connect/utils/context.py +195 -37
  99. snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
  100. snowflake/snowpark_connect/utils/env_utils.py +5 -1
  101. snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
  102. snowflake/snowpark_connect/utils/identifiers.py +137 -3
  103. snowflake/snowpark_connect/utils/io_utils.py +57 -1
  104. snowflake/snowpark_connect/utils/java_stored_procedure.py +151 -0
  105. snowflake/snowpark_connect/utils/java_udaf_utils.py +321 -0
  106. snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
  107. snowflake/snowpark_connect/utils/jvm_udf_utils.py +281 -0
  108. snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
  109. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
  110. snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
  111. snowflake/snowpark_connect/utils/profiling.py +25 -8
  112. snowflake/snowpark_connect/utils/scala_udf_utils.py +185 -340
  113. snowflake/snowpark_connect/utils/sequence.py +21 -0
  114. snowflake/snowpark_connect/utils/session.py +64 -28
  115. snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
  116. snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
  117. snowflake/snowpark_connect/utils/telemetry.py +192 -40
  118. snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
  119. snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
  120. snowflake/snowpark_connect/utils/udf_cache.py +117 -41
  121. snowflake/snowpark_connect/utils/udf_helper.py +39 -37
  122. snowflake/snowpark_connect/utils/udf_utils.py +133 -14
  123. snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
  124. snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
  125. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  126. snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
  127. snowflake/snowpark_connect/version.py +1 -1
  128. snowflake/snowpark_decoder/dp_session.py +6 -2
  129. snowflake/snowpark_decoder/spark_decoder.py +12 -0
  130. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +14 -4
  131. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +16 -7
  132. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +139 -168
  133. snowflake/snowpark_connect/hidden_column.py +0 -39
  134. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  135. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  136. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  137. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  138. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  139. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  140. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  141. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  142. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  143. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  144. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  145. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  146. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  147. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  148. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  149. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  150. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  151. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  152. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  153. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  154. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  155. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  156. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  157. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  158. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  159. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  160. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  161. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  162. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  163. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  164. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  165. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  166. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  167. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  168. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  169. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  170. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  171. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  172. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  173. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  174. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  175. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  176. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  177. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  178. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  179. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  180. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  181. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  182. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  183. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  184. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  185. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  186. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  187. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  188. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  189. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  190. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  191. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  192. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  193. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  194. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  195. {snowpark_connect-0.27.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  196. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  197. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  198. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  199. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  200. {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py
@@ -2,9 +2,8 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
-import re
+import copy
 from dataclasses import dataclass
-from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
@@ -12,22 +11,31 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
-from snowflake.snowpark.types import DataType
+from snowflake.snowpark.types import DataType, StructType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
+    make_unique_snowpark_name,
+)
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
+from snowflake.snowpark_connect.dataframe_container import (
+    AggregateMetadata,
+    DataFrameContainer,
 )
-from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
 from snowflake.snowpark_connect.expression.map_expression import (
     map_single_column_expression,
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.utils import (
+    create_pivot_column_condition,
+    map_pivot_value_to_spark_column_name,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils import expression_transformer
 from snowflake.snowpark_connect.utils.context import (
-    get_is_evaluating_sql,
+    grouping_by_scala_udf_key,
     set_current_grouping_columns,
-    temporary_pivot_expression,
 )
 
 
@@ -49,6 +57,61 @@ def map_group_by_aggregate(
     result = input_df_actual.group_by(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
     )
+
+    for rel_aggregate_expression, aggregate_original_column in zip(
+        rel.aggregate.aggregate_expressions, columns.aggregation_columns
+    ):
+        aggregate_original_data_type = aggregate_original_column.data_type
+
+        if not (
+            rel_aggregate_expression.HasField("unresolved_function")
+            and rel_aggregate_expression.unresolved_function.function_name == "reduce"
+        ) or not isinstance(aggregate_original_data_type, StructType):
+            continue
+
+        cols = []
+        new_snowpark_column_names = []
+        new_snowpark_column_types = [
+            field.datatype for field in aggregate_original_data_type.fields
+        ]
+
+        if not result.columns or len(result.columns) != 1:
+            raise ValueError(
+                "Expected result DataFrame to have exactly one column for reduce(StructType)"
+            )
+        aggregate_col = snowpark_fn.col(result.columns[0])
+
+        # Extract each field from the StructType result after aggregation to create separate columns
+        for spark_col_name in input_df_container.column_map.get_spark_columns():
+            unique_snowpark_name = make_unique_snowpark_name(spark_col_name)
+            cols.append(
+                snowpark_fn.get(aggregate_col, snowpark_fn.lit(spark_col_name)).alias(
+                    unique_snowpark_name
+                )
+            )
+            new_snowpark_column_names.append(unique_snowpark_name)
+
+        result = result.select(*cols)
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=input_df_container.column_map.get_spark_columns(),
+            snowpark_column_names=new_snowpark_column_names,
+            snowpark_column_types=new_snowpark_column_types,
+        )
+
+    # Store aggregate metadata for ORDER BY resolution
+    aggregate_metadata = AggregateMetadata(
+        input_column_map=input_df_container.column_map,
+        input_dataframe=input_df_actual,
+        grouping_expressions=list(rel.aggregate.grouping_expressions),
+        aggregate_expressions=list(rel.aggregate.aggregate_expressions),
+        spark_columns=columns.spark_names(),
+        raw_aggregations=[
+            (col.spark_name, TypedColumn(col.expression, col.data_type))
+            for col in columns.aggregation_columns
+        ],
+    )
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -56,6 +119,8 @@ def map_group_by_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_df_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
+        aggregate_metadata=aggregate_metadata,
     )
 
 
@@ -77,6 +142,11 @@ def map_rollup_aggregate(
     result = input_df_actual.rollup(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
     )
+
+    # NOTE: Do NOT attach aggregate_metadata for ROLLUP
+    # Spark does not allow ORDER BY to reference pre-aggregation columns for ROLLUP
+    # Only regular GROUP BY supports this
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -84,6 +154,7 @@ def map_rollup_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
     )
 
 
@@ -105,6 +176,11 @@ def map_cube_aggregate(
     result = input_df_actual.cube(*columns.grouping_expressions()).agg(
         *columns.aggregation_expressions()
    )
+
+    # NOTE: Do NOT attach aggregate_metadata for CUBE
+    # Spark does not allow ORDER BY to reference pre-aggregation columns for CUBE
+    # Only regular GROUP BY supports this
+
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=columns.spark_names(),
@@ -112,6 +188,7 @@ def map_cube_aggregate(
         snowpark_column_types=columns.data_types(),
         column_qualifiers=columns.get_qualifiers(),
         parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=columns.get_equivalent_snowpark_names(),
     )
 
 
@@ -136,220 +213,111 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]
 
-    used_columns = {pivot_column[1].col._expression.name}
-    if get_is_evaluating_sql():
-        # When evaluating SQL spark doesn't trim columns from the result
-        used_columns = {"*"}
-    else:
-        for expression in rel.aggregate.aggregate_expressions:
-            matched_identifiers = re.findall(
-                r'unparsed_identifier: "(.*)"', expression.__str__()
-            )
-            for identifier in matched_identifiers:
-                mapped_col = input_container.column_map.spark_to_col.get(
-                    identifier, None
-                )
-                if mapped_col:
-                    used_columns.add(mapped_col[0].snowpark_name)
+    if not pivot_values:
+        distinct_col_values = (
+            input_df_actual.select(pivot_column[1].col)
+            .distinct()
+            .sort(snowpark_fn.asc_nulls_first(pivot_column[1].col))
+            .collect()
+        )
+        pivot_values = [
+            row[0].as_dict() if isinstance(row[0], snowpark.Row) else row[0]
+            for row in distinct_col_values
+        ]
 
-    if len(columns.grouping_expressions()) == 0:
-        # Snowpark doesn't support multiple aggregations in pivot without groupBy
-        # So we need to perform each aggregation separately and then combine results
-        if len(columns.aggregation_expressions(unalias=True)) > 1:
-            agg_expressions = columns.aggregation_expressions(unalias=True)
-            agg_metadata = columns.aggregation_columns
-            num_agg_functions = len(agg_expressions)
-
-            spark_names = []
-            pivot_results = []
-            for i, agg_expr in enumerate(agg_expressions):
-                pivot_result = (
-                    input_df_actual.select(*used_columns)
-                    .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                    .agg(agg_expr)
+    agg_expressions = columns.aggregation_expressions(unalias=True)
+
+    spark_col_names = []
+    aggregations = []
+    final_pivot_names = []
+    grouping_columns_qualifiers = []
+    grouping_eq_snowpark_names = []
+
+    grouping_columns = columns.grouping_expressions()
+    if grouping_columns:
+        for col in grouping_columns:
+            snowpark_name = col.get_name()
+            spark_col_name = input_container.column_map.get_spark_column_name_from_snowpark_column_name(
+                snowpark_name
+            )
+            qualifiers = input_container.column_map.get_qualifiers_for_snowpark_column(
+                snowpark_name
+            )
+            grouping_columns_qualifiers.append(qualifiers)
+            spark_col_names.append(spark_col_name)
+            grouping_eq_snowpark_names.append(
+                input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                    snowpark_name
                 )
-                for col_name in pivot_result.columns:
-                    spark_names.append(
-                        f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
-                    )
-                pivot_results.append(pivot_result)
-
-            result = pivot_results[0]
-            for pivot_result in pivot_results[1:]:
-                result = result.cross_join(pivot_result)
-
-            pivot_columns_per_agg = len(pivot_results[0].columns)
-            reordered_spark_names = []
-            reordered_snowpark_names = []
-            reordered_types = []
-            column_selectors = []
-
-            for pivot_idx in range(pivot_columns_per_agg):
-                for agg_idx in range(num_agg_functions):
-                    current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
-                    if current_pos < len(spark_names):
-                        idx = current_pos + 1  # 1-based indexing for Snowpark
-                        reordered_spark_names.append(spark_names[current_pos])
-                        reordered_snowpark_names.append(f"${idx}")
-                        reordered_types.append(
-                            result.schema.fields[current_pos].datatype
-                        )
-                        column_selectors.append(snowpark_fn.col(f"${idx}"))
-
-            return DataFrameContainer.create_with_column_mapping(
-                dataframe=result.select(*column_selectors),
-                spark_column_names=reordered_spark_names,
-                snowpark_column_names=reordered_snowpark_names,
-                column_qualifiers=[[]] * len(reordered_spark_names),
-                parent_column_name_map=input_container.column_map,
-                snowpark_column_types=reordered_types,
             )
-        else:
-            result = (
-                input_df_actual.select(*used_columns)
-                .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-                .agg(*columns.aggregation_expressions(unalias=True))
+
+    for pv_value in pivot_values:
+        pv_value_spark, pv_is_null = map_pivot_value_to_spark_column_name(pv_value)
+
+        for i, agg_expression in enumerate(agg_expressions):
+            agg_fun_expr = copy.deepcopy(agg_expression._expr1)
+
+            condition = create_pivot_column_condition(
+                pivot_column[1].col,
+                pv_value,
+                pv_is_null,
+                pivot_column[1].typ if isinstance(pv_value, (list, dict)) else None,
             )
-    else:
-        result = (
-            input_df_actual.group_by(*columns.grouping_expressions())
-            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions(unalias=True))
-        )
 
-    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+            expression_transformer.inject_condition_to_all_agg_functions(
+                agg_fun_expr, condition
+            )
 
-    # Calculate number of pivot values for proper Spark-compatible indexing
-    total_pivot_columns = len(result.columns) - len(agg_name_list)
-    num_pivot_values = (
-        total_pivot_columns // len(columns.aggregation_columns)
-        if len(columns.aggregation_columns) > 0
-        else 1
-    )
+            curr_expression = Column(agg_fun_expr)
 
-    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
-        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
-            return None
-        else:
-            index = (col_index - len(agg_name_list)) // num_pivot_values
-            return columns.aggregation_columns[index].spark_name
-
-    spark_columns = []
-    for col in [
-        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
-        for i, c in enumerate(result.columns)
-    ]:
-        spark_col = (
-            input_container.column_map.get_spark_column_name_from_snowpark_column_name(
-                col, allow_non_exists=True
+            spark_col_name = (
+                f"{pv_value_spark}_{columns.aggregation_columns[i].spark_name}"
+                if len(agg_expressions) > 1
+                else f"{pv_value_spark}"
             )
-        )
 
-        if spark_col is not None:
-            spark_columns.append(spark_col)
-        else:
-            # Handle NULL column names to match Spark behavior (lowercase 'null')
-            if col == "NULL":
-                spark_columns.append(col.lower())
-            else:
-                spark_columns.append(col)
-
-    grouping_cols_count = len(agg_name_list)
-    pivot_cols = result.columns[grouping_cols_count:]
-    spark_pivot_cols = spark_columns[grouping_cols_count:]
-
-    num_agg_functions = len(columns.aggregation_columns)
-    num_pivot_values = len(pivot_cols) // num_agg_functions
-
-    reordered_snowpark_cols = []
-    reordered_spark_cols = []
-    column_indices = []  # 1-based indexing
-
-    for i in range(grouping_cols_count):
-        reordered_snowpark_cols.append(result.columns[i])
-        reordered_spark_cols.append(spark_columns[i])
-        column_indices.append(i + 1)
-
-    for pivot_idx in range(num_pivot_values):
-        for agg_idx in range(num_agg_functions):
-            current_pos = agg_idx * num_pivot_values + pivot_idx
-            if current_pos < len(pivot_cols):
-                reordered_snowpark_cols.append(pivot_cols[current_pos])
-                reordered_spark_cols.append(spark_pivot_cols[current_pos])
-                original_index = grouping_cols_count + current_pos
-                column_indices.append(original_index + 1)
-
-    reordered_result = result.select(
-        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+            snowpark_col_name = make_column_names_snowpark_compatible(
+                [spark_col_name],
+                rel.common.plan_id,
+                len(grouping_columns) + len(agg_expressions),
+            )[0]
+
+            curr_expression = curr_expression.alias(snowpark_col_name)
+
+            aggregations.append(curr_expression)
+            spark_col_names.append(spark_col_name)
+            final_pivot_names.append(snowpark_col_name)
+
+    result_df = (
+        input_df_actual.group_by(*grouping_columns)
+        .agg(*aggregations)
+        .select(*grouping_columns, *final_pivot_names)
     )
 
     return DataFrameContainer.create_with_column_mapping(
-        dataframe=reordered_result,
-        spark_column_names=reordered_spark_cols,
-        snowpark_column_names=[f"${idx}" for idx in column_indices],
-        column_qualifiers=(
-            columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
-        ),
-        parent_column_name_map=input_container.column_map,
+        dataframe=result_df,
+        spark_column_names=spark_col_names,
+        snowpark_column_names=result_df.columns,
         snowpark_column_types=[
-            result.schema.fields[idx - 1].datatype for idx in column_indices
+            result_df.schema.fields[idx].datatype
+            for idx, _ in enumerate(result_df.columns)
         ],
+        column_qualifiers=grouping_columns_qualifiers
+        + [set() for _ in final_pivot_names],
+        parent_column_name_map=input_container.column_map,
+        equivalent_snowpark_names=grouping_eq_snowpark_names
+        + [set() for _ in final_pivot_names],
     )
 
 
-def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
-    # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
-
-    # 1. "'Java'" -> Java
-    # 2. "'""C++""'" -> "C++"
-    # 3. "'""""''Scala''""""'" -> ""'Scala'""
-
-    # As we can see:
-    # 1. the whole content is always nested in a double quote followed by a single quote ("'<content>'").
-    # 2. the string content is nested in single quotes ('<string_content>')
-    # 3. double quote is escased by another double quote, this is snowflake behavior
-    # 4. if there is a single quote followed by a single quote, the first single quote needs to be preserved in the output
-
-    try:
-        # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
-        # extract the content between the outermost double quote followed by a single quote "'
-        content = match.group(1)
-        # convert the escaped double quote to the actual double quote
-        content = content.replace('""', '"')
-        escape_single_quote_placeholder = "__SAS_PLACEHOLDER_ESCAPE_SINGLE_QUOTE__"
-        # replace two consecutive single quote in the content with a placeholder, the first single quote needs to be preserved
-        content = re.sub(r"''", escape_single_quote_placeholder, content)
-        # remove the solo single quote, they are not part of the string content
-        content = re.sub(r"'", "", content)
-        # replace the placeholder with the single quote which we want to preserve
-        result = content.replace(escape_single_quote_placeholder, "'")
-        return f"{result}_{opt_alias}" if opt_alias else result
-    except Exception:
-        # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
-        spark_string = ""
-        for entry in list(filter(None, double_quote_list)):
-            if "'" in entry:
-                entry = entry.replace("'", "")
-                if len(entry) > 0:
-                    spark_string += entry
-            elif entry.isdigit() or re.compile(r"^\d+?\.\d+?$").match(entry):
-                # skip quoting digits or decimal numbers as column names.
-                spark_string += entry
-            else:
-                spark_string += '"' + entry + '"'
-        return snowpark_cname if spark_string == "" else spark_string
-
-
 @dataclass(frozen=True)
 class _ColumnMetadata:
     expression: snowpark.Column
     spark_name: str
     snowpark_name: str
     data_type: DataType
-    qualifiers: list[str]
+    qualifiers: set[ColumnQualifier]
+    equivalent_snowpark_names: set[str]
 
 
 @dataclass(frozen=True)
@@ -385,7 +353,7 @@ class _Columns:
             col.spark_name for col in self.grouping_columns + self.aggregation_columns
         ]
 
-    def get_qualifiers(self) -> list[list[str]]:
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         return [
             col.qualifiers for col in self.grouping_columns + self.aggregation_columns
         ]
@@ -399,6 +367,12 @@ class _Columns:
             if col.data_type is not None
         ]
 
+    def get_equivalent_snowpark_names(self) -> list[set[str]]:
+        return [
+            col.equivalent_snowpark_names
+            for col in self.grouping_columns + self.aggregation_columns
+        ]
+
 
 def map_aggregate_helper(
     rel: relation_proto.Relation, pivot: bool = False, skip_alias: bool = False
@@ -413,71 +387,80 @@ def map_aggregate_helper(
     typer = ExpressionTyper(input_df)
     schema_inferrable = True
 
-    with temporary_pivot_expression(pivot):
-        for exp in grouping_expressions:
+    for exp in grouping_expressions:
+        with grouping_by_scala_udf_key(
+            exp.WhichOneof("expr_type") == "common_inline_user_defined_function"
+            and exp.common_inline_user_defined_function.scalar_scala_udf is not None
+        ):
             new_name, snowpark_column = map_single_column_expression(
                 exp, input_container.column_map, typer
             )
-            alias = make_column_names_snowpark_compatible(
-                [new_name], rel.common.plan_id, len(groupings)
-            )[0]
-            groupings.append(
-                _ColumnMetadata(
-                    snowpark_column.col
-                    if skip_alias
-                    else snowpark_column.col.alias(alias),
-                    new_name,
-                    None if skip_alias else alias,
-                    None if pivot else snowpark_column.typ,
-                    snowpark_column.get_qualifiers(),
-                )
-            )
 
-        grouping_cols = [g.spark_name for g in groupings]
-        set_current_grouping_columns(grouping_cols)
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings)
+        )[0]
 
-    for exp in expressions:
-        new_name, snowpark_column = map_single_column_expression(
-            exp, input_container.column_map, typer
+        equivalent_snowpark_names = (
+            input_container.column_map.get_equivalent_snowpark_names_for_snowpark_name(
+                snowpark_column.col.get_name()
            )
-        alias = make_column_names_snowpark_compatible(
-            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
-        )[0]
+        )
 
-        def type_agg_expr(
-            agg_exp: TypedColumn, schema_inferrable: bool
-        ) -> DataType | None:
-            if pivot or not schema_inferrable:
-                return None
-            try:
-                return agg_exp.typ
-            except Exception:
-                # This type used for schema inference optimization purposes.
-                # typer may not be able to infer the type of some expressions
-                # in that case we return None, and the optimization will not be applied.
-                return None
-
-        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
-        if agg_col_typ is None:
-            schema_inferrable = False
-
-        aggregations.append(
-            _ColumnMetadata(
-                snowpark_column.col
-                if skip_alias
-                else snowpark_column.col.alias(alias),
-                new_name,
-                None if skip_alias else alias,
-                agg_col_typ,
-                [],
-            )
+        groupings.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                None if pivot else snowpark_column.typ,
+                qualifiers=snowpark_column.get_qualifiers(),
+                equivalent_snowpark_names=equivalent_snowpark_names,
            )
+        )
 
-    return (
-        input_container,
-        _Columns(
-            grouping_columns=groupings,
-            aggregation_columns=aggregations,
-            can_infer_schema=schema_inferrable,
-        ),
+    grouping_cols = [g.spark_name for g in groupings]
+    set_current_grouping_columns(grouping_cols)
+
+    for exp in expressions:
+        new_name, snowpark_column = map_single_column_expression(
+            exp, input_container.column_map, typer
+        )
+        alias = make_column_names_snowpark_compatible(
+            [new_name], rel.common.plan_id, len(groupings) + len(aggregations)
+        )[0]
+
+        def type_agg_expr(
+            agg_exp: TypedColumn, schema_inferrable: bool
+        ) -> DataType | None:
+            if pivot or not schema_inferrable:
+                return None
+            try:
+                return agg_exp.typ
+            except Exception:
+                # This type used for schema inference optimization purposes.
+                # typer may not be able to infer the type of some expressions
+                # in that case we return None, and the optimization will not be applied.
+                return None
+
+        agg_col_typ = type_agg_expr(snowpark_column, schema_inferrable)
+        if agg_col_typ is None:
+            schema_inferrable = False
+
+        aggregations.append(
+            _ColumnMetadata(
+                snowpark_column.col if skip_alias else snowpark_column.col.alias(alias),
+                new_name,
+                None if skip_alias else alias,
+                agg_col_typ,
+                qualifiers=set(),
+                equivalent_snowpark_names=set(),
+            )
        )
+
+    return (
+        input_container,
+        _Columns(
+            grouping_columns=groupings,
+            aggregation_columns=aggregations,
+            can_infer_schema=schema_inferrable,
+        ),
+    )
snowflake/snowpark_connect/relation/map_catalog.py
@@ -8,6 +8,8 @@ import pandas
 import pyspark.sql.connect.proto.catalog_pb2 as catalog_proto
 
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
+from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.relation.catalogs import CATALOGS
 from snowflake.snowpark_connect.relation.catalogs.utils import (
     CURRENT_CATALOG_NAME,
@@ -148,4 +150,6 @@ def map_catalog(
             return get_current_catalog().uncacheTable(rel.uncache_table.table_name)
         case other:
             # TODO: list_function implementation is blocked on SNOW-1787268
-            raise SnowparkConnectNotImplementedError(f"Other Relation {other}")
+            exception = SnowparkConnectNotImplementedError(f"Other Relation {other}")
+            attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+            raise exception