snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.

This version of snowpark-connect has been flagged as potentially problematic.

Files changed (56)
  1. snowflake/snowpark_connect/config.py +19 -14
  2. snowflake/snowpark_connect/error/error_utils.py +32 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  5. snowflake/snowpark_connect/expression/literal.py +9 -12
  6. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +8 -1
  8. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  9. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  10. snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
  11. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  12. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  13. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  14. snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
  15. snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
  16. snowflake/snowpark_connect/relation/map_extension.py +58 -24
  17. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  18. snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
  19. snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
  20. snowflake/snowpark_connect/relation/map_sql.py +40 -196
  21. snowflake/snowpark_connect/relation/map_udtf.py +4 -4
  22. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  25. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  26. snowflake/snowpark_connect/relation/read/utils.py +7 -6
  27. snowflake/snowpark_connect/relation/utils.py +170 -1
  28. snowflake/snowpark_connect/relation/write/map_write.py +306 -87
  29. snowflake/snowpark_connect/server.py +34 -5
  30. snowflake/snowpark_connect/type_mapping.py +6 -2
  31. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  32. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  33. snowflake/snowpark_connect/utils/session.py +21 -4
  34. snowflake/snowpark_connect/utils/telemetry.py +213 -61
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  36. snowflake/snowpark_connect/version.py +1 -1
  37. snowflake/snowpark_decoder/__init__.py +0 -0
  38. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  39. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  40. snowflake/snowpark_decoder/dp_session.py +111 -0
  41. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  42. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
  43. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
  44. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
  45. spark/__init__.py +0 -0
  46. spark/connect/__init__.py +0 -0
  47. spark/connect/envelope_pb2.py +31 -0
  48. spark/connect/envelope_pb2.pyi +46 -0
  49. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  50. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
  51. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
  52. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
  53. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
  54. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
  55. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
16
16
  from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
17
17
 
18
18
 
19
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\x92\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
19
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
20
20
 
21
21
  _globals = globals()
22
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
38
38
  _globals['_TABLEARGUMENTINFO']._serialized_start=931
39
39
  _globals['_TABLEARGUMENTINFO']._serialized_end=1027
40
40
  _globals['_AGGREGATE']._serialized_start=1030
41
- _globals['_AGGREGATE']._serialized_end=1688
42
- _globals['_AGGREGATE_PIVOT']._serialized_start=1363
43
- _globals['_AGGREGATE_PIVOT']._serialized_end=1461
44
- _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1463
45
- _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1526
46
- _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1529
47
- _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1688
41
+ _globals['_AGGREGATE']._serialized_end=1741
42
+ _globals['_AGGREGATE_PIVOT']._serialized_start=1416
43
+ _globals['_AGGREGATE_PIVOT']._serialized_end=1514
44
+ _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
45
+ _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
46
+ _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
47
+ _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
48
48
  # @@protoc_insertion_point(module_scope)
@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
75
75
  def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...
76
76
 
77
77
  class Aggregate(_message.Message):
78
- __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
78
+ __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
79
79
  class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
80
80
  __slots__ = ()
81
81
  GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]
@@ -108,10 +108,12 @@ class Aggregate(_message.Message):
108
108
  AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
109
109
  PIVOT_FIELD_NUMBER: _ClassVar[int]
110
110
  GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
111
+ HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
111
112
  input: _relations_pb2.Relation
112
113
  group_type: Aggregate.GroupType
113
114
  grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
114
115
  aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
115
116
  pivot: Aggregate.Pivot
116
117
  grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
117
- def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ...) -> None: ...
118
+ having_condition: _expressions_pb2.Expression
119
+ def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...
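
Note: the new having_condition field lets the snowflake.ext.Aggregate extension carry a HAVING predicate alongside the grouping and aggregate expressions, rather than encoding it as a separate post-aggregation filter. Roughly, given a Spark Connect session named spark, a query like the following is the kind of statement that would populate it (illustrative table and column names; the client-side mapping is not shown in this diff):

    spark.sql("""
        SELECT dept, SUM(salary) AS total
        FROM employees
        GROUP BY dept
        HAVING SUM(salary) > 100000
    """)
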
@@ -8,7 +8,10 @@ import typing
8
8
  import pandas
9
9
  import pyspark.sql.connect.proto.common_pb2 as common_proto
10
10
  import pyspark.sql.connect.proto.types_pb2 as types_proto
11
- from snowflake.core.exceptions import NotFoundError
11
+ from pyspark.sql.connect.client.core import Retrying
12
+ from snowflake.core.exceptions import APIError, NotFoundError
13
+ from snowflake.core.schema import Schema
14
+ from snowflake.core.table import Table, TableColumn
12
15
 
13
16
  from snowflake.snowpark import functions
14
17
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -22,6 +25,7 @@ from snowflake.snowpark_connect.config import (
22
25
  global_config,
23
26
  )
24
27
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
28
+ from snowflake.snowpark_connect.error.exceptions import MaxRetryExceeded
25
29
  from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import (
26
30
  AbstractSparkCatalog,
27
31
  _get_current_snowflake_schema,
@@ -39,6 +43,37 @@ from snowflake.snowpark_connect.utils.telemetry import (
39
43
  from snowflake.snowpark_connect.utils.udf_cache import cached_udf
40
44
 
41
45
 
46
+ def _is_retryable_api_error(e: Exception) -> bool:
47
+ """
48
+ Determine if an APIError should be retried.
49
+
50
+ Only retry on server errors, rate limiting, and transient network issues.
51
+ Don't retry on client errors like authentication, authorization, or validation failures.
52
+ """
53
+ if not isinstance(e, APIError):
54
+ return False
55
+
56
+ # Check if the error has a status_code attribute
57
+ if hasattr(e, "status_code"):
58
+ # Retry on server errors (5xx), rate limiting (429), and some client errors (400)
59
+ # 400 can be transient in some cases (like the original error trace shows)
60
+ return e.status_code in [400, 429, 500, 502, 503, 504]
61
+
62
+ # For APIErrors without explicit status codes, check the message
63
+ error_msg = str(e).lower()
64
+ retryable_patterns = [
65
+ "timeout",
66
+ "connection",
67
+ "network",
68
+ "unavailable",
69
+ "temporary",
70
+ "rate limit",
71
+ "throttle",
72
+ ]
73
+
74
+ return any(pattern in error_msg for pattern in retryable_patterns)
75
+
76
+
42
77
  def _normalize_identifier(identifier: str | None) -> str | None:
43
78
  if identifier is None:
44
79
  return None
@@ -73,10 +108,25 @@ class SnowflakeCatalog(AbstractSparkCatalog):
73
108
  )
74
109
  sp_catalog = get_or_create_snowpark_session().catalog
75
110
 
76
- dbs = sp_catalog.list_schemas(
77
- database=sf_quote(sf_database),
78
- pattern=_normalize_identifier(sf_schema),
79
- )
111
+ dbs: list[Schema] | None = None
112
+ for attempt in Retrying(
113
+ max_retries=5,
114
+ initial_backoff=100, # 100ms
115
+ max_backoff=5000, # 5 s
116
+ backoff_multiplier=2.0,
117
+ jitter=100,
118
+ min_jitter_threshold=200,
119
+ can_retry=_is_retryable_api_error,
120
+ ):
121
+ with attempt:
122
+ dbs = sp_catalog.list_schemas(
123
+ database=sf_quote(sf_database),
124
+ pattern=_normalize_identifier(sf_schema),
125
+ )
126
+ if dbs is None:
127
+ raise MaxRetryExceeded(
128
+ f"Failed to fetch databases {f'with pattern {pattern} ' if pattern is not None else ''}after all retry attempts"
129
+ )
80
130
  names: list[str] = list()
81
131
  catalogs: list[str] = list()
82
132
  descriptions: list[str | None] = list()
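
The same Retrying wrapper recurs in getDatabase, getTable, and listColumns below with identical parameters. Stripped of the catalog specifics, the pattern is essentially the following sketch, where do_request is a hypothetical stand-in for any of the snowflake.core calls (list_schemas, get_schema, get_table, list_columns); the parameter values mirror the ones in this diff:

    from pyspark.sql.connect.client.core import Retrying

    result = None
    for attempt in Retrying(
        max_retries=5,
        initial_backoff=100,      # milliseconds
        max_backoff=5000,         # 5 s cap
        backoff_multiplier=2.0,
        jitter=100,
        min_jitter_threshold=200,
        can_retry=_is_retryable_api_error,  # only 400/429/5xx and transient network errors
    ):
        with attempt:
            # retryable errors trigger another attempt; others propagate immediately
            result = do_request()
    if result is None:
        # defensive check mirroring the catalog methods in this file
        raise MaxRetryExceeded("request still failing after all retry attempts")
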
@@ -112,9 +162,24 @@ class SnowflakeCatalog(AbstractSparkCatalog):
112
162
  )
113
163
  sp_catalog = get_or_create_snowpark_session().catalog
114
164
 
115
- db = sp_catalog.get_schema(
116
- schema=sf_quote(sf_schema), database=sf_quote(sf_database)
117
- )
165
+ db: Schema | None = None
166
+ for attempt in Retrying(
167
+ max_retries=5,
168
+ initial_backoff=100, # 100ms
169
+ max_backoff=5000, # 5 s
170
+ backoff_multiplier=2.0,
171
+ jitter=100,
172
+ min_jitter_threshold=200,
173
+ can_retry=_is_retryable_api_error,
174
+ ):
175
+ with attempt:
176
+ db = sp_catalog.get_schema(
177
+ schema=sf_quote(sf_schema), database=sf_quote(sf_database)
178
+ )
179
+ if db is None:
180
+ raise MaxRetryExceeded(
181
+ f"Failed to fetch database {spark_dbName} after all retry attempts"
182
+ )
118
183
 
119
184
  name = unquote_if_quoted(db.name)
120
185
  return pandas.DataFrame(
@@ -241,11 +306,27 @@ class SnowflakeCatalog(AbstractSparkCatalog):
241
306
  "Calling into another catalog is not currently supported"
242
307
  )
243
308
 
244
- table = sp_catalog.get_table(
245
- database=sf_quote(sf_database),
246
- schema=sf_quote(sf_schema),
247
- table_name=sf_quote(table_name),
248
- )
309
+ table: Table | None = None
310
+ for attempt in Retrying(
311
+ max_retries=5,
312
+ initial_backoff=100, # 100ms
313
+ max_backoff=5000, # 5 s
314
+ backoff_multiplier=2.0,
315
+ jitter=100,
316
+ min_jitter_threshold=200,
317
+ can_retry=_is_retryable_api_error,
318
+ ):
319
+ with attempt:
320
+ table = sp_catalog.get_table(
321
+ database=sf_quote(sf_database),
322
+ schema=sf_quote(sf_schema),
323
+ table_name=sf_quote(table_name),
324
+ )
325
+
326
+ if table is None:
327
+ raise MaxRetryExceeded(
328
+ f"Failed to fetch table {spark_tableName} after all retry attempts"
329
+ )
249
330
 
250
331
  return pandas.DataFrame(
251
332
  {
@@ -286,6 +367,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
286
367
  ) -> pandas.DataFrame:
287
368
  """List all columns in a table/view, optionally database name filter can be provided."""
288
369
  sp_catalog = get_or_create_snowpark_session().catalog
370
+ columns: list[TableColumn] | None = None
289
371
  if spark_dbName is None:
290
372
  catalog, sf_database, sf_schema, sf_table = _process_multi_layer_identifier(
291
373
  spark_tableName
@@ -294,15 +376,39 @@ class SnowflakeCatalog(AbstractSparkCatalog):
294
376
  raise SnowparkConnectNotImplementedError(
295
377
  "Calling into another catalog is not currently supported"
296
378
  )
297
- columns = sp_catalog.list_columns(
298
- database=sf_quote(sf_database),
299
- schema=sf_quote(sf_schema),
300
- table_name=sf_quote(sf_table),
301
- )
379
+ for attempt in Retrying(
380
+ max_retries=5,
381
+ initial_backoff=100, # 100ms
382
+ max_backoff=5000, # 5 s
383
+ backoff_multiplier=2.0,
384
+ jitter=100,
385
+ min_jitter_threshold=200,
386
+ can_retry=_is_retryable_api_error,
387
+ ):
388
+ with attempt:
389
+ columns = sp_catalog.list_columns(
390
+ database=sf_quote(sf_database),
391
+ schema=sf_quote(sf_schema),
392
+ table_name=sf_quote(sf_table),
393
+ )
302
394
  else:
303
- columns = sp_catalog.list_columns(
304
- schema=sf_quote(spark_dbName),
305
- table_name=sf_quote(spark_tableName),
395
+ for attempt in Retrying(
396
+ max_retries=5,
397
+ initial_backoff=100, # 100ms
398
+ max_backoff=5000, # 5 s
399
+ backoff_multiplier=2.0,
400
+ jitter=100,
401
+ min_jitter_threshold=200,
402
+ can_retry=_is_retryable_api_error,
403
+ ):
404
+ with attempt:
405
+ columns = sp_catalog.list_columns(
406
+ schema=sf_quote(spark_dbName),
407
+ table_name=sf_quote(spark_tableName),
408
+ )
409
+ if columns is None:
410
+ raise MaxRetryExceeded(
411
+ f"Failed to fetch columns of {spark_tableName} after all retry attempts"
306
412
  )
307
413
  names: list[str] = list()
308
414
  descriptions: list[str | None] = list()
@@ -4,10 +4,14 @@
4
4
 
5
5
  import re
6
6
  from dataclasses import dataclass
7
+ from typing import Optional
7
8
 
8
9
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
9
10
 
11
+ import snowflake.snowpark.functions as snowpark_fn
10
12
  from snowflake import snowpark
13
+ from snowflake.snowpark import Column
14
+ from snowflake.snowpark._internal.analyzer.unary_expression import Alias
11
15
  from snowflake.snowpark.types import DataType
12
16
  from snowflake.snowpark_connect.column_name_handler import (
13
17
  make_column_names_snowpark_compatible,
@@ -21,6 +25,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
21
25
  from snowflake.snowpark_connect.relation.map_relation import map_relation
22
26
  from snowflake.snowpark_connect.typed_column import TypedColumn
23
27
  from snowflake.snowpark_connect.utils.context import (
28
+ get_is_evaluating_sql,
24
29
  set_current_grouping_columns,
25
30
  temporary_pivot_expression,
26
31
  )
@@ -131,19 +136,109 @@ def map_pivot_aggregate(
131
136
  get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
132
137
  ]
133
138
 
139
+ used_columns = {pivot_column[1].col._expression.name}
140
+ if get_is_evaluating_sql():
141
+ # When evaluating SQL spark doesn't trim columns from the result
142
+ used_columns = {"*"}
143
+ else:
144
+ for expression in rel.aggregate.aggregate_expressions:
145
+ matched_identifiers = re.findall(
146
+ r'unparsed_identifier: "(.*)"', expression.__str__()
147
+ )
148
+ for identifier in matched_identifiers:
149
+ mapped_col = input_container.column_map.spark_to_col.get(
150
+ identifier, None
151
+ )
152
+ if mapped_col:
153
+ used_columns.add(mapped_col[0].snowpark_name)
154
+
134
155
  if len(columns.grouping_expressions()) == 0:
135
- result = input_df_actual.pivot(
136
- pivot_column[1].col, pivot_values if pivot_values else None
137
- ).agg(*columns.aggregation_expressions())
156
+ # Snowpark doesn't support multiple aggregations in pivot without groupBy
157
+ # So we need to perform each aggregation separately and then combine results
158
+ if len(columns.aggregation_expressions(unalias=True)) > 1:
159
+ agg_expressions = columns.aggregation_expressions(unalias=True)
160
+ agg_metadata = columns.aggregation_columns
161
+ num_agg_functions = len(agg_expressions)
162
+
163
+ spark_names = []
164
+ pivot_results = []
165
+ for i, agg_expr in enumerate(agg_expressions):
166
+ pivot_result = (
167
+ input_df_actual.select(*used_columns)
168
+ .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
169
+ .agg(agg_expr)
170
+ )
171
+ for col_name in pivot_result.columns:
172
+ spark_names.append(
173
+ f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
174
+ )
175
+ pivot_results.append(pivot_result)
176
+
177
+ result = pivot_results[0]
178
+ for pivot_result in pivot_results[1:]:
179
+ result = result.cross_join(pivot_result)
180
+
181
+ pivot_columns_per_agg = len(pivot_results[0].columns)
182
+ reordered_spark_names = []
183
+ reordered_snowpark_names = []
184
+ reordered_types = []
185
+ column_selectors = []
186
+
187
+ for pivot_idx in range(pivot_columns_per_agg):
188
+ for agg_idx in range(num_agg_functions):
189
+ current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
190
+ if current_pos < len(spark_names):
191
+ idx = current_pos + 1 # 1-based indexing for Snowpark
192
+ reordered_spark_names.append(spark_names[current_pos])
193
+ reordered_snowpark_names.append(f"${idx}")
194
+ reordered_types.append(
195
+ result.schema.fields[current_pos].datatype
196
+ )
197
+ column_selectors.append(snowpark_fn.col(f"${idx}"))
198
+
199
+ return DataFrameContainer.create_with_column_mapping(
200
+ dataframe=result.select(*column_selectors),
201
+ spark_column_names=reordered_spark_names,
202
+ snowpark_column_names=reordered_snowpark_names,
203
+ column_qualifiers=[[]] * len(reordered_spark_names),
204
+ parent_column_name_map=input_container.column_map,
205
+ snowpark_column_types=reordered_types,
206
+ )
207
+ else:
208
+ result = (
209
+ input_df_actual.select(*used_columns)
210
+ .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
211
+ .agg(*columns.aggregation_expressions(unalias=True))
212
+ )
138
213
  else:
139
214
  result = (
140
215
  input_df_actual.group_by(*columns.grouping_expressions())
141
216
  .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
142
- .agg(*columns.aggregation_expressions())
217
+ .agg(*columns.aggregation_expressions(unalias=True))
143
218
  )
144
219
 
220
+ agg_name_list = [c.spark_name for c in columns.grouping_columns]
221
+
222
+ # Calculate number of pivot values for proper Spark-compatible indexing
223
+ total_pivot_columns = len(result.columns) - len(agg_name_list)
224
+ num_pivot_values = (
225
+ total_pivot_columns // len(columns.aggregation_columns)
226
+ if len(columns.aggregation_columns) > 0
227
+ else 1
228
+ )
229
+
230
+ def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
231
+ if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
232
+ return None
233
+ else:
234
+ index = (col_index - len(agg_name_list)) // num_pivot_values
235
+ return columns.aggregation_columns[index].spark_name
236
+
145
237
  spark_columns = []
146
- for col in [string_parser(s) for s in result.columns]:
238
+ for col in [
239
+ pivot_column_name(c, _get_agg_exp_alias_for_col(i))
240
+ for i, c in enumerate(result.columns)
241
+ ]:
147
242
  spark_col = (
148
243
  input_container.column_map.get_spark_column_name_from_snowpark_column_name(
149
244
  col, allow_non_exists=True
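
For context on the multi-aggregation branch above: Snowpark's pivot without group_by accepts only a single aggregation, so each aggregate is pivoted separately, the results are cross-joined, and the columns are re-interleaved to match Spark's "<pivot value>_<aggregation alias>" naming and ordering. At the PySpark API level, the branch serves calls along these lines (illustrative column names and values):

    from pyspark.sql import functions as F

    df.groupBy().pivot("year", [2023, 2024]).agg(
        F.sum("amount").alias("total"),
        F.avg("amount").alias("avg"),
    )
    # expected Spark column order: 2023_total, 2023_avg, 2024_total, 2024_avg
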
@@ -153,22 +248,57 @@ def map_pivot_aggregate(
153
248
  if spark_col is not None:
154
249
  spark_columns.append(spark_col)
155
250
  else:
156
- spark_columns.append(col)
251
+ # Handle NULL column names to match Spark behavior (lowercase 'null')
252
+ if col == "NULL":
253
+ spark_columns.append(col.lower())
254
+ else:
255
+ spark_columns.append(col)
256
+
257
+ grouping_cols_count = len(agg_name_list)
258
+ pivot_cols = result.columns[grouping_cols_count:]
259
+ spark_pivot_cols = spark_columns[grouping_cols_count:]
260
+
261
+ num_agg_functions = len(columns.aggregation_columns)
262
+ num_pivot_values = len(pivot_cols) // num_agg_functions
263
+
264
+ reordered_snowpark_cols = []
265
+ reordered_spark_cols = []
266
+ column_indices = [] # 1-based indexing
267
+
268
+ for i in range(grouping_cols_count):
269
+ reordered_snowpark_cols.append(result.columns[i])
270
+ reordered_spark_cols.append(spark_columns[i])
271
+ column_indices.append(i + 1)
272
+
273
+ for pivot_idx in range(num_pivot_values):
274
+ for agg_idx in range(num_agg_functions):
275
+ current_pos = agg_idx * num_pivot_values + pivot_idx
276
+ if current_pos < len(pivot_cols):
277
+ reordered_snowpark_cols.append(pivot_cols[current_pos])
278
+ reordered_spark_cols.append(spark_pivot_cols[current_pos])
279
+ original_index = grouping_cols_count + current_pos
280
+ column_indices.append(original_index + 1)
281
+
282
+ reordered_result = result.select(
283
+ *[snowpark_fn.col(f"${idx}") for idx in column_indices]
284
+ )
157
285
 
158
- agg_name_list = [c.spark_name for c in columns.grouping_columns]
159
286
  return DataFrameContainer.create_with_column_mapping(
160
- dataframe=result,
161
- spark_column_names=agg_name_list + spark_columns[len(agg_name_list) :],
162
- snowpark_column_names=result.columns,
287
+ dataframe=reordered_result,
288
+ spark_column_names=reordered_spark_cols,
289
+ snowpark_column_names=[f"${idx}" for idx in column_indices],
163
290
  column_qualifiers=(
164
291
  columns.get_qualifiers()[: len(agg_name_list)]
165
- + [[]] * (len(spark_columns) - len(agg_name_list))
292
+ + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
166
293
  ),
167
294
  parent_column_name_map=input_container.column_map,
295
+ snowpark_column_types=[
296
+ result.schema.fields[idx - 1].datatype for idx in column_indices
297
+ ],
168
298
  )
169
299
 
170
300
 
171
- def string_parser(s):
301
+ def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
172
302
  # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
173
303
 
174
304
  # 1. "'Java'" -> Java
@@ -183,7 +313,7 @@ def string_parser(s):
183
313
 
184
314
  try:
185
315
  # handling values that are used as pivoted columns
186
- match = re.match(r'^"\'(.*)\'"$', s)
316
+ match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
187
317
  # extract the content between the outermost double quote followed by a single quote "'
188
318
  content = match.group(1)
189
319
  # convert the escaped double quote to the actual double quote
@@ -195,10 +325,10 @@ def string_parser(s):
195
325
  content = re.sub(r"'", "", content)
196
326
  # replace the placeholder with the single quote which we want to preserve
197
327
  result = content.replace(escape_single_quote_placeholder, "'")
198
- return result
328
+ return f"{result}_{opt_alias}" if opt_alias else result
199
329
  except Exception:
200
330
  # fallback to the original logic, handling aliased column names
201
- double_quote_list = re.findall(r'"(.*?)"', s)
331
+ double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
202
332
  spark_string = ""
203
333
  for entry in list(filter(None, double_quote_list)):
204
334
  if "'" in entry:
@@ -210,7 +340,7 @@ def string_parser(s):
210
340
  spark_string += entry
211
341
  else:
212
342
  spark_string += '"' + entry + '"'
213
- return s if spark_string == "" else spark_string
343
+ return snowpark_cname if spark_string == "" else spark_string
214
344
 
215
345
 
216
346
  @dataclass(frozen=True)
@@ -231,8 +361,14 @@ class _Columns:
231
361
  def grouping_expressions(self) -> list[snowpark.Column]:
232
362
  return [col.expression for col in self.grouping_columns]
233
363
 
234
- def aggregation_expressions(self) -> list[snowpark.Column]:
235
- return [col.expression for col in self.aggregation_columns]
364
+ def aggregation_expressions(self, unalias: bool = False) -> list[snowpark.Column]:
365
+ def _unalias(col: snowpark.Column) -> snowpark.Column:
366
+ if unalias and hasattr(col, "_expr1") and isinstance(col._expr1, Alias):
367
+ return _unalias(Column(col._expr1.child))
368
+ else:
369
+ return col
370
+
371
+ return [_unalias(col.expression) for col in self.aggregation_columns]
236
372
 
237
373
  def expressions(self) -> list[snowpark.Column]:
238
374
  return self.grouping_expressions() + self.aggregation_expressions()
@@ -6,10 +6,12 @@ import ast
6
6
  import json
7
7
  import sys
8
8
  from collections import defaultdict
9
+ from copy import copy
9
10
 
10
11
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
11
12
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
12
13
  import pyspark.sql.connect.proto.types_pb2 as types_proto
14
+ from pyspark.errors import PySparkValueError
13
15
  from pyspark.errors.exceptions.base import AnalysisException
14
16
  from pyspark.serializers import CloudPickleSerializer
15
17
 
@@ -44,6 +46,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
44
46
  from snowflake.snowpark_connect.relation.map_relation import map_relation
45
47
  from snowflake.snowpark_connect.relation.utils import (
46
48
  TYPE_MAP_FOR_TO_SCHEMA,
49
+ can_sort_be_flattened,
47
50
  snowpark_functions_col,
48
51
  )
49
52
  from snowflake.snowpark_connect.type_mapping import (
@@ -266,6 +269,7 @@ def map_project(
266
269
 
267
270
  aliased_col = mapper.col.alias(snowpark_column)
268
271
  select_list.append(aliased_col)
272
+
269
273
  new_snowpark_columns.append(snowpark_column)
270
274
  new_spark_columns.append(spark_name)
271
275
  column_types.extend(mapper.types)
@@ -342,6 +346,12 @@ def map_sort(
342
346
 
343
347
  sort_order = sort.order
344
348
 
349
+ if not sort_order:
350
+ raise PySparkValueError(
351
+ error_class="CANNOT_BE_EMPTY",
352
+ message="At least one column must be specified.",
353
+ )
354
+
345
355
  if len(sort_order) == 1:
346
356
  parsed_col_name = split_fully_qualified_spark_name(
347
357
  sort_order[0].child.unresolved_attribute.unparsed_identifier
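
The new empty-order guard mirrors PySpark's own error, so a Sort with no ordering expressions fails the same way a plain PySpark call does (illustrative; exact message rendering may differ):

    df.orderBy()
    # roughly: PySparkValueError (CANNOT_BE_EMPTY): at least one column must be specified
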
@@ -422,7 +432,30 @@ def map_sort(
422
432
  # TODO: sort.isglobal.
423
433
  if not order_specified:
424
434
  ascending = None
425
- result = input_df.sort(cols, ascending=ascending)
435
+
436
+ select_statement = getattr(input_df, "_select_statement", None)
437
+ sort_expressions = [c._expression for c in cols]
438
+ if (
439
+ can_sort_be_flattened(select_statement, *sort_expressions)
440
+ and input_df._ops_after_agg is None
441
+ ):
442
+ # "flattened" order by that will allow using dropped columns
443
+ new = copy(select_statement)
444
+ new.from_ = select_statement.from_.to_subqueryable()
445
+ new.pre_actions = new.from_.pre_actions
446
+ new.post_actions = new.from_.post_actions
447
+ new.order_by = sort_expressions + (select_statement.order_by or [])
448
+ new.column_states = select_statement.column_states
449
+ new._merge_projection_complexity_with_subquery = False
450
+ new.df_ast_ids = (
451
+ select_statement.df_ast_ids.copy()
452
+ if select_statement.df_ast_ids is not None
453
+ else None
454
+ )
455
+ new.attributes = select_statement.attributes
456
+ result = input_df._with_plan(new)
457
+ else:
458
+ result = input_df.sort(cols, ascending=ascending)
426
459
 
427
460
  return DataFrameContainer(
428
461
  result,
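
The flattened branch keeps the ORDER BY inside the existing SELECT instead of wrapping the projection in a subquery, which is what lets the sort reference a column the projection has already dropped, matching Spark behaviour such as (illustrative):

    df.select("name").orderBy("age")
    # "age" is no longer in the projection, but the merged order_by can still
    # reference it because the sort is attached to the existing SelectStatement
    # rather than to a new outer query.
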
@@ -666,10 +699,29 @@ def map_with_columns_renamed(
666
699
  )
667
700
 
668
701
  # Validate for naming conflicts
669
- new_names_list = list(dict(rel.with_columns_renamed.rename_columns_map).values())
702
+ rename_map = dict(rel.with_columns_renamed.rename_columns_map)
703
+ new_names_list = list(rename_map.values())
670
704
  seen = set()
671
705
  for new_name in new_names_list:
672
- if column_map.has_spark_column(new_name):
706
+ # Check if this new name conflicts with existing columns
707
+ # But allow renaming a column to a different case version of itself
708
+ is_case_insensitive_self_rename = False
709
+ if not global_config.spark_sql_caseSensitive:
710
+ # Find the source column(s) that map to this new name
711
+ source_columns = [
712
+ old_name
713
+ for old_name, new_name_candidate in rename_map.items()
714
+ if new_name_candidate == new_name
715
+ ]
716
+ # Check if any source column is the same as new name when case-insensitive
717
+ is_case_insensitive_self_rename = any(
718
+ source_col.lower() == new_name.lower() for source_col in source_columns
719
+ )
720
+
721
+ if (
722
+ column_map.has_spark_column(new_name)
723
+ and not is_case_insensitive_self_rename
724
+ ):
673
725
  # Spark doesn't allow reusing existing names, even if the result df will not contain duplicate columns
674
726
  raise _column_exists_error(new_name)
675
727
  if (global_config.spark_sql_caseSensitive and new_name in seen) or (
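
The relaxed conflict check is aimed at case-only renames, which previously tripped the duplicate-column error whenever spark.sql.caseSensitive is false (its default). For example (illustrative column names):

    df.withColumnRenamed("customer_id", "CUSTOMER_ID")   # case-only rename, now accepted
    df.withColumnRenamed("customer_id", "order_id")      # still rejected if order_id already exists
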
@@ -1056,14 +1108,12 @@ def map_group_map(
1056
1108
  snowpark_grouping_expressions: list[snowpark.Column] = []
1057
1109
  typer = ExpressionTyper(input_df)
1058
1110
  group_name_list: list[str] = []
1059
- qualifiers = []
1060
1111
  for exp in grouping_expressions:
1061
1112
  new_name, snowpark_column = map_single_column_expression(
1062
1113
  exp, input_container.column_map, typer
1063
1114
  )
1064
1115
  snowpark_grouping_expressions.append(snowpark_column.col)
1065
1116
  group_name_list.append(new_name)
1066
- qualifiers.append(snowpark_column.get_qualifiers())
1067
1117
  if rel.group_map.func.python_udf is None:
1068
1118
  raise ValueError("group_map relation without python udf is not supported")
1069
1119
 
@@ -1105,13 +1155,14 @@ def map_group_map(
1105
1155
  result = input_df.group_by(*snowpark_grouping_expressions).apply_in_pandas(
1106
1156
  callable_func, output_type
1107
1157
  )
1108
-
1109
- qualifiers.extend([[]] * (len(result.columns) - len(group_name_list)))
1158
+ # The UDTF `apply_in_pandas` generates a new table whose output schema
1159
+ # can be entirely different from that of the input Snowpark DataFrame.
1160
+ # As a result, the output DataFrame should not use qualifiers based on the input group by columns.
1110
1161
  return DataFrameContainer.create_with_column_mapping(
1111
1162
  dataframe=result,
1112
1163
  spark_column_names=[field.name for field in output_type],
1113
1164
  snowpark_column_names=result.columns,
1114
- column_qualifiers=qualifiers,
1165
+ column_qualifiers=None,
1115
1166
  parent_column_name_map=input_container.column_map,
1116
1167
  )
1117
1168
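
For context, this is the grouped-map pandas UDF path: the output schema comes entirely from the user-declared schema, not from the input DataFrame, so the result no longer carries qualifiers derived from the group-by columns. At the PySpark level the shape of such a call is (illustrative function and schema):

    def center(pdf):
        pdf["value"] = pdf["value"] - pdf["value"].mean()
        return pdf

    df.groupBy("dept").applyInPandas(center, schema="dept string, value double")
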