snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -14
- snowflake/snowpark_connect/error/error_utils.py +32 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +9 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +8 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
- snowflake/snowpark_connect/relation/map_extension.py +58 -24
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
- snowflake/snowpark_connect/relation/map_sql.py +40 -196
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +7 -6
- snowflake/snowpark_connect/relation/utils.py +170 -1
- snowflake/snowpark_connect/relation/write/map_write.py +306 -87
- snowflake/snowpark_connect/server.py +34 -5
- snowflake/snowpark_connect/type_mapping.py +6 -2
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0

snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py

@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)

@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_TABLEARGUMENTINFO']._serialized_start=931
   _globals['_TABLEARGUMENTINFO']._serialized_end=1027
   _globals['_AGGREGATE']._serialized_start=1030
-  _globals['_AGGREGATE']._serialized_end=
-  _globals['_AGGREGATE_PIVOT']._serialized_start=
-  _globals['_AGGREGATE_PIVOT']._serialized_end=
-  _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=
-  _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=
-  _globals['_AGGREGATE_GROUPTYPE']._serialized_start=
-  _globals['_AGGREGATE_GROUPTYPE']._serialized_end=
+  _globals['_AGGREGATE']._serialized_end=1741
+  _globals['_AGGREGATE_PIVOT']._serialized_start=1416
+  _globals['_AGGREGATE_PIVOT']._serialized_end=1514
+  _globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
+  _globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
+  _globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
+  _globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
 # @@protoc_insertion_point(module_scope)

snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi

@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
     def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...
 
 class Aggregate(_message.Message):
-    __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
+    __slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
     class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
         __slots__ = ()
         GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]

@@ -108,10 +108,12 @@ class Aggregate(_message.Message):
     AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
     PIVOT_FIELD_NUMBER: _ClassVar[int]
     GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
+    HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
     input: _relations_pb2.Relation
     group_type: Aggregate.GroupType
     grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     pivot: Aggregate.Pivot
     grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
-
+    having_condition: _expressions_pb2.Expression
+    def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...

snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py

@@ -8,7 +8,10 @@ import typing
 import pandas
 import pyspark.sql.connect.proto.common_pb2 as common_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
-from
+from pyspark.sql.connect.client.core import Retrying
+from snowflake.core.exceptions import APIError, NotFoundError
+from snowflake.core.schema import Schema
+from snowflake.core.table import Table, TableColumn
 
 from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (

@@ -22,6 +25,7 @@ from snowflake.snowpark_connect.config import (
     global_config,
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.error.exceptions import MaxRetryExceeded
 from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import (
     AbstractSparkCatalog,
     _get_current_snowflake_schema,

@@ -39,6 +43,37 @@ from snowflake.snowpark_connect.utils.telemetry import (
 from snowflake.snowpark_connect.utils.udf_cache import cached_udf
 
 
+def _is_retryable_api_error(e: Exception) -> bool:
+    """
+    Determine if an APIError should be retried.
+
+    Only retry on server errors, rate limiting, and transient network issues.
+    Don't retry on client errors like authentication, authorization, or validation failures.
+    """
+    if not isinstance(e, APIError):
+        return False
+
+    # Check if the error has a status_code attribute
+    if hasattr(e, "status_code"):
+        # Retry on server errors (5xx), rate limiting (429), and some client errors (400)
+        # 400 can be transient in some cases (like the original error trace shows)
+        return e.status_code in [400, 429, 500, 502, 503, 504]
+
+    # For APIErrors without explicit status codes, check the message
+    error_msg = str(e).lower()
+    retryable_patterns = [
+        "timeout",
+        "connection",
+        "network",
+        "unavailable",
+        "temporary",
+        "rate limit",
+        "throttle",
+    ]
+
+    return any(pattern in error_msg for pattern in retryable_patterns)
+
+
 def _normalize_identifier(identifier: str | None) -> str | None:
     if identifier is None:
         return None
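
Note: a minimal, self-contained sketch of how a predicate like _is_retryable_api_error above classifies errors. FakeAPIError is a hypothetical stand-in for snowflake.core.exceptions.APIError so the example runs without a Snowflake connection; the status codes and message patterns mirror the diff.

# Sketch only: FakeAPIError stands in for snowflake.core.exceptions.APIError.
class FakeAPIError(Exception):
    def __init__(self, message, status_code=None):
        super().__init__(message)
        if status_code is not None:
            self.status_code = status_code

RETRYABLE_STATUS = {400, 429, 500, 502, 503, 504}
RETRYABLE_PATTERNS = ("timeout", "connection", "network", "unavailable",
                      "temporary", "rate limit", "throttle")

def is_retryable(e: Exception) -> bool:
    # Status code wins when present; otherwise fall back to message patterns.
    if not isinstance(e, FakeAPIError):
        return False
    if hasattr(e, "status_code"):
        return e.status_code in RETRYABLE_STATUS
    return any(p in str(e).lower() for p in RETRYABLE_PATTERNS)

assert is_retryable(FakeAPIError("service unavailable", status_code=503))
assert is_retryable(FakeAPIError("connection reset by peer"))
assert not is_retryable(FakeAPIError("object does not exist", status_code=404))
assert not is_retryable(ValueError("not an API error at all"))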

@@ -73,10 +108,25 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         )
         sp_catalog = get_or_create_snowpark_session().catalog
 
-        dbs =
-
-
-
+        dbs: list[Schema] | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100, # 100ms
+            max_backoff=5000, # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                dbs = sp_catalog.list_schemas(
+                    database=sf_quote(sf_database),
+                    pattern=_normalize_identifier(sf_schema),
+                )
+        if dbs is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch databases {f'with pattern {pattern} ' if pattern is not None else ''}after all retry attempts"
+            )
         names: list[str] = list()
         catalogs: list[str] = list()
         descriptions: list[str | None] = list()
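
Note: Retrying (imported from pyspark.sql.connect.client.core) is iterated to produce attempt context managers, as in the hunk above. The loop below is a hand-rolled sketch of the same retry-with-exponential-backoff shape, written out so the control flow is visible; it is an assumption for illustration, not pyspark's implementation, and it takes backoff values in seconds rather than the millisecond-style values passed in the diff.

import random
import time

def retry_call(fn, *, max_retries=5, initial_backoff=0.1, max_backoff=5.0,
               backoff_multiplier=2.0, jitter=0.1, can_retry=lambda e: False):
    """Call fn(), retrying only when can_retry(exception) says the failure is transient."""
    backoff = initial_backoff
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception as e:
            if attempt == max_retries or not can_retry(e):
                raise
            # Exponential backoff with a little jitter, capped at max_backoff.
            time.sleep(min(backoff, max_backoff) + random.uniform(0, jitter))
            backoff *= backoff_multiplier

# Usage sketch (names from the diff, call commented out):
# dbs = retry_call(lambda: sp_catalog.list_schemas(database=..., pattern=...),
#                  can_retry=_is_retryable_api_error)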

@@ -112,9 +162,24 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         )
         sp_catalog = get_or_create_snowpark_session().catalog
 
-        db =
-
-
+        db: Schema | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100, # 100ms
+            max_backoff=5000, # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                db = sp_catalog.get_schema(
+                    schema=sf_quote(sf_schema), database=sf_quote(sf_database)
+                )
+        if db is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch database {spark_dbName} after all retry attempts"
+            )
 
         name = unquote_if_quoted(db.name)
         return pandas.DataFrame(

@@ -241,11 +306,27 @@ class SnowflakeCatalog(AbstractSparkCatalog):
                 "Calling into another catalog is not currently supported"
             )
 
-        table =
-
-
-
-
+        table: Table | None = None
+        for attempt in Retrying(
+            max_retries=5,
+            initial_backoff=100, # 100ms
+            max_backoff=5000, # 5 s
+            backoff_multiplier=2.0,
+            jitter=100,
+            min_jitter_threshold=200,
+            can_retry=_is_retryable_api_error,
+        ):
+            with attempt:
+                table = sp_catalog.get_table(
+                    database=sf_quote(sf_database),
+                    schema=sf_quote(sf_schema),
+                    table_name=sf_quote(table_name),
+                )
+
+        if table is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch table {spark_tableName} after all retry attempts"
+            )
 
         return pandas.DataFrame(
             {

@@ -286,6 +367,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
     ) -> pandas.DataFrame:
         """List all columns in a table/view, optionally database name filter can be provided."""
         sp_catalog = get_or_create_snowpark_session().catalog
+        columns: list[TableColumn] | None = None
         if spark_dbName is None:
             catalog, sf_database, sf_schema, sf_table = _process_multi_layer_identifier(
                 spark_tableName

@@ -294,15 +376,39 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             raise SnowparkConnectNotImplementedError(
                 "Calling into another catalog is not currently supported"
             )
-
-
-
-
-
+            for attempt in Retrying(
+                max_retries=5,
+                initial_backoff=100, # 100ms
+                max_backoff=5000, # 5 s
+                backoff_multiplier=2.0,
+                jitter=100,
+                min_jitter_threshold=200,
+                can_retry=_is_retryable_api_error,
+            ):
+                with attempt:
+                    columns = sp_catalog.list_columns(
+                        database=sf_quote(sf_database),
+                        schema=sf_quote(sf_schema),
+                        table_name=sf_quote(sf_table),
+                    )
         else:
-
-
-
+            for attempt in Retrying(
+                max_retries=5,
+                initial_backoff=100, # 100ms
+                max_backoff=5000, # 5 s
+                backoff_multiplier=2.0,
+                jitter=100,
+                min_jitter_threshold=200,
+                can_retry=_is_retryable_api_error,
+            ):
+                with attempt:
+                    columns = sp_catalog.list_columns(
+                        schema=sf_quote(spark_dbName),
+                        table_name=sf_quote(spark_tableName),
+                    )
+        if columns is None:
+            raise MaxRetryExceeded(
+                f"Failed to fetch columns of {spark_tableName} after all retry attempts"
             )
         names: list[str] = list()
         descriptions: list[str | None] = list()

snowflake/snowpark_connect/relation/map_aggregate.py

@@ -4,10 +4,14 @@
 
 import re
 from dataclasses import dataclass
+from typing import Optional
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
+import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
+from snowflake.snowpark import Column
+from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,

@@ -21,6 +25,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
     set_current_grouping_columns,
     temporary_pivot_expression,
 )

@@ -131,19 +136,109 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]
 
+    used_columns = {pivot_column[1].col._expression.name}
+    if get_is_evaluating_sql():
+        # When evaluating SQL spark doesn't trim columns from the result
+        used_columns = {"*"}
+    else:
+        for expression in rel.aggregate.aggregate_expressions:
+            matched_identifiers = re.findall(
+                r'unparsed_identifier: "(.*)"', expression.__str__()
+            )
+            for identifier in matched_identifiers:
+                mapped_col = input_container.column_map.spark_to_col.get(
+                    identifier, None
+                )
+                if mapped_col:
+                    used_columns.add(mapped_col[0].snowpark_name)
+
     if len(columns.grouping_expressions()) == 0:
-
-
-
+        # Snowpark doesn't support multiple aggregations in pivot without groupBy
+        # So we need to perform each aggregation separately and then combine results
+        if len(columns.aggregation_expressions(unalias=True)) > 1:
+            agg_expressions = columns.aggregation_expressions(unalias=True)
+            agg_metadata = columns.aggregation_columns
+            num_agg_functions = len(agg_expressions)
+
+            spark_names = []
+            pivot_results = []
+            for i, agg_expr in enumerate(agg_expressions):
+                pivot_result = (
+                    input_df_actual.select(*used_columns)
+                    .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                    .agg(agg_expr)
+                )
+                for col_name in pivot_result.columns:
+                    spark_names.append(
+                        f"{pivot_column_name(col_name)}_{agg_metadata[i].spark_name}"
+                    )
+                pivot_results.append(pivot_result)
+
+            result = pivot_results[0]
+            for pivot_result in pivot_results[1:]:
+                result = result.cross_join(pivot_result)
+
+            pivot_columns_per_agg = len(pivot_results[0].columns)
+            reordered_spark_names = []
+            reordered_snowpark_names = []
+            reordered_types = []
+            column_selectors = []
+
+            for pivot_idx in range(pivot_columns_per_agg):
+                for agg_idx in range(num_agg_functions):
+                    current_pos = agg_idx * pivot_columns_per_agg + pivot_idx
+                    if current_pos < len(spark_names):
+                        idx = current_pos + 1 # 1-based indexing for Snowpark
+                        reordered_spark_names.append(spark_names[current_pos])
+                        reordered_snowpark_names.append(f"${idx}")
+                        reordered_types.append(
+                            result.schema.fields[current_pos].datatype
+                        )
+                        column_selectors.append(snowpark_fn.col(f"${idx}"))
+
+            return DataFrameContainer.create_with_column_mapping(
+                dataframe=result.select(*column_selectors),
+                spark_column_names=reordered_spark_names,
+                snowpark_column_names=reordered_snowpark_names,
+                column_qualifiers=[[]] * len(reordered_spark_names),
+                parent_column_name_map=input_container.column_map,
+                snowpark_column_types=reordered_types,
+            )
+        else:
+            result = (
+                input_df_actual.select(*used_columns)
+                .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+                .agg(*columns.aggregation_expressions(unalias=True))
+            )
     else:
         result = (
             input_df_actual.group_by(*columns.grouping_expressions())
             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions())
+            .agg(*columns.aggregation_expressions(unalias=True))
         )
 
+    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+
+    # Calculate number of pivot values for proper Spark-compatible indexing
+    total_pivot_columns = len(result.columns) - len(agg_name_list)
+    num_pivot_values = (
+        total_pivot_columns // len(columns.aggregation_columns)
+        if len(columns.aggregation_columns) > 0
+        else 1
+    )
+
+    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
+        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
+            return None
+        else:
+            index = (col_index - len(agg_name_list)) // num_pivot_values
+            return columns.aggregation_columns[index].spark_name
+
     spark_columns = []
-    for col in [
+    for col in [
+        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
+        for i, c in enumerate(result.columns)
+    ]:
         spark_col = (
             input_container.column_map.get_spark_column_name_from_snowpark_column_name(
                 col, allow_non_exists=True

@@ -153,22 +248,57 @@ def map_pivot_aggregate(
         if spark_col is not None:
             spark_columns.append(spark_col)
         else:
-
+            # Handle NULL column names to match Spark behavior (lowercase 'null')
+            if col == "NULL":
+                spark_columns.append(col.lower())
+            else:
+                spark_columns.append(col)
+
+    grouping_cols_count = len(agg_name_list)
+    pivot_cols = result.columns[grouping_cols_count:]
+    spark_pivot_cols = spark_columns[grouping_cols_count:]
+
+    num_agg_functions = len(columns.aggregation_columns)
+    num_pivot_values = len(pivot_cols) // num_agg_functions
+
+    reordered_snowpark_cols = []
+    reordered_spark_cols = []
+    column_indices = [] # 1-based indexing
+
+    for i in range(grouping_cols_count):
+        reordered_snowpark_cols.append(result.columns[i])
+        reordered_spark_cols.append(spark_columns[i])
+        column_indices.append(i + 1)
+
+    for pivot_idx in range(num_pivot_values):
+        for agg_idx in range(num_agg_functions):
+            current_pos = agg_idx * num_pivot_values + pivot_idx
+            if current_pos < len(pivot_cols):
+                reordered_snowpark_cols.append(pivot_cols[current_pos])
+                reordered_spark_cols.append(spark_pivot_cols[current_pos])
+                original_index = grouping_cols_count + current_pos
+                column_indices.append(original_index + 1)
+
+    reordered_result = result.select(
+        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+    )
 
-    agg_name_list = [c.spark_name for c in columns.grouping_columns]
     return DataFrameContainer.create_with_column_mapping(
-        dataframe=
-        spark_column_names=
-        snowpark_column_names=
+        dataframe=reordered_result,
+        spark_column_names=reordered_spark_cols,
+        snowpark_column_names=[f"${idx}" for idx in column_indices],
        column_qualifiers=(
            columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(
+            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
        ),
        parent_column_name_map=input_container.column_map,
+        snowpark_column_types=[
+            result.schema.fields[idx - 1].datatype for idx in column_indices
+        ],
     )
 
 
-def
+def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):
 
     # 1. "'Java'" -> Java
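
Note: the reordering logic above converts Snowpark's pivot layout (all pivot values for aggregate 1, then all for aggregate 2, ...) into Spark's layout (all aggregates for pivot value 1, then for pivot value 2, ...). The index math in isolation, with made-up column names:

# Hypothetical column names; only the interleaving arithmetic is taken from the diff.
pivot_cols = ["Java_sum", "Python_sum", "Java_avg", "Python_avg"]
num_agg_functions = 2
num_pivot_values = len(pivot_cols) // num_agg_functions

reordered = []
for pivot_idx in range(num_pivot_values):
    for agg_idx in range(num_agg_functions):
        current_pos = agg_idx * num_pivot_values + pivot_idx
        reordered.append(pivot_cols[current_pos])

assert reordered == ["Java_sum", "Java_avg", "Python_sum", "Python_avg"]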

@@ -183,7 +313,7 @@ def string_parser(s):
 
     try:
         # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$',
+        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
         # extract the content between the outermost double quote followed by a single quote "'
         content = match.group(1)
         # convert the escaped double quote to the actual double quote

@@ -195,10 +325,10 @@ def string_parser(s):
         content = re.sub(r"'", "", content)
         # replace the placeholder with the single quote which we want to preserve
         result = content.replace(escape_single_quote_placeholder, "'")
-        return result
+        return f"{result}_{opt_alias}" if opt_alias else result
     except Exception:
         # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"',
+        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
         spark_string = ""
         for entry in list(filter(None, double_quote_list)):
             if "'" in entry:

@@ -210,7 +340,7 @@ def string_parser(s):
                 spark_string += entry
             else:
                 spark_string += '"' + entry + '"'
-        return
+        return snowpark_cname if spark_string == "" else spark_string
 
 
 @dataclass(frozen=True)
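
Note: a simplified sketch of what pivot_column_name does with Snowpark's quoted pivot labels: an outer double quote wrapping a single-quoted literal is stripped, and an aggregate alias is appended when several aggregates are pivoted. It omits the escape-placeholder handling and the fallback branch of the real function.

import re
from typing import Optional

def simplified_pivot_column_name(snowpark_cname: str, opt_alias: Optional[str] = None) -> str:
    # '"\'Java\'"' -> 'Java'; with opt_alias="sum(earnings)" -> 'Java_sum(earnings)'
    match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
    if match:
        result = match.group(1)
        return f"{result}_{opt_alias}" if opt_alias else result
    return snowpark_cname  # names that are not quoted literals pass through

assert simplified_pivot_column_name('"\'Java\'"') == "Java"
assert simplified_pivot_column_name('"\'Java\'"', "sum(earnings)") == "Java_sum(earnings)"
assert simplified_pivot_column_name("ID") == "ID"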

@@ -231,8 +361,14 @@ class _Columns:
     def grouping_expressions(self) -> list[snowpark.Column]:
         return [col.expression for col in self.grouping_columns]
 
-    def aggregation_expressions(self) -> list[snowpark.Column]:
-
+    def aggregation_expressions(self, unalias: bool = False) -> list[snowpark.Column]:
+        def _unalias(col: snowpark.Column) -> snowpark.Column:
+            if unalias and hasattr(col, "_expr1") and isinstance(col._expr1, Alias):
+                return _unalias(Column(col._expr1.child))
+            else:
+                return col
+
+        return [_unalias(col.expression) for col in self.aggregation_columns]
 
     def expressions(self) -> list[snowpark.Column]:
         return self.grouping_expressions() + self.aggregation_expressions()
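
Note: the new unalias=True path peels Alias wrappers off aggregate expressions before they are handed to pivot(...).agg(...). The recursion can be sketched on a toy expression tree; ToyAlias and ToyColumn are hypothetical stand-ins for Snowpark's internal Alias and Column types.

from dataclasses import dataclass

@dataclass
class ToyColumn:
    name: str

@dataclass
class ToyAlias:
    child: object  # may itself be a ToyAlias or a ToyColumn
    name: str

def unalias(expr):
    # Keep unwrapping until the underlying expression is reached.
    while isinstance(expr, ToyAlias):
        expr = expr.child
    return expr

wrapped = ToyAlias(ToyAlias(ToyColumn("sum(earnings)"), "s"), "total")
assert unalias(wrapped) == ToyColumn("sum(earnings)")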

snowflake/snowpark_connect/relation/map_column_ops.py

@@ -6,10 +6,12 @@ import ast
 import json
 import sys
 from collections import defaultdict
+from copy import copy
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
+from pyspark.errors import PySparkValueError
 from pyspark.errors.exceptions.base import AnalysisException
 from pyspark.serializers import CloudPickleSerializer
 

@@ -44,6 +46,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.utils import (
     TYPE_MAP_FOR_TO_SCHEMA,
+    can_sort_be_flattened,
     snowpark_functions_col,
 )
 from snowflake.snowpark_connect.type_mapping import (

@@ -266,6 +269,7 @@ def map_project(
 
             aliased_col = mapper.col.alias(snowpark_column)
             select_list.append(aliased_col)
+
             new_snowpark_columns.append(snowpark_column)
             new_spark_columns.append(spark_name)
             column_types.extend(mapper.types)

@@ -342,6 +346,12 @@ def map_sort(
 
     sort_order = sort.order
 
+    if not sort_order:
+        raise PySparkValueError(
+            error_class="CANNOT_BE_EMPTY",
+            message="At least one column must be specified.",
+        )
+
     if len(sort_order) == 1:
         parsed_col_name = split_fully_qualified_spark_name(
             sort_order[0].child.unresolved_attribute.unparsed_identifier

@@ -422,7 +432,30 @@ def map_sort(
     # TODO: sort.isglobal.
     if not order_specified:
         ascending = None
-
+
+    select_statement = getattr(input_df, "_select_statement", None)
+    sort_expressions = [c._expression for c in cols]
+    if (
+        can_sort_be_flattened(select_statement, *sort_expressions)
+        and input_df._ops_after_agg is None
+    ):
+        # "flattened" order by that will allow using dropped columns
+        new = copy(select_statement)
+        new.from_ = select_statement.from_.to_subqueryable()
+        new.pre_actions = new.from_.pre_actions
+        new.post_actions = new.from_.post_actions
+        new.order_by = sort_expressions + (select_statement.order_by or [])
+        new.column_states = select_statement.column_states
+        new._merge_projection_complexity_with_subquery = False
+        new.df_ast_ids = (
+            select_statement.df_ast_ids.copy()
+            if select_statement.df_ast_ids is not None
+            else None
+        )
+        new.attributes = select_statement.attributes
+        result = input_df._with_plan(new)
+    else:
+        result = input_df.sort(cols, ascending=ascending)
 
     return DataFrameContainer(
         result,
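
Note: the new branch in map_sort attaches the ORDER BY to a copy of the current SELECT instead of wrapping the plan in another subquery, so columns that were projected away can still be used as sort keys. Roughly, in SQL terms (illustrative only; table and column names are made up):

# Illustrative shapes of the generated SQL, not the exact output of the planner.
nested_sort = """
SELECT * FROM (
    SELECT "NAME" FROM employees        -- "SALARY" already dropped here
) ORDER BY "SALARY"                     -- fails: SALARY is out of scope
"""

flattened_sort = """
SELECT "NAME" FROM employees ORDER BY "SALARY"  -- same SELECT, so SALARY is still in scope
"""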

@@ -666,10 +699,29 @@ def map_with_columns_renamed(
     )
 
     # Validate for naming conflicts
-
+    rename_map = dict(rel.with_columns_renamed.rename_columns_map)
+    new_names_list = list(rename_map.values())
     seen = set()
     for new_name in new_names_list:
-        if
+        # Check if this new name conflicts with existing columns
+        # But allow renaming a column to a different case version of itself
+        is_case_insensitive_self_rename = False
+        if not global_config.spark_sql_caseSensitive:
+            # Find the source column(s) that map to this new name
+            source_columns = [
+                old_name
+                for old_name, new_name_candidate in rename_map.items()
+                if new_name_candidate == new_name
+            ]
+            # Check if any source column is the same as new name when case-insensitive
+            is_case_insensitive_self_rename = any(
+                source_col.lower() == new_name.lower() for source_col in source_columns
+            )
+
+        if (
+            column_map.has_spark_column(new_name)
+            and not is_case_insensitive_self_rename
+        ):
             # Spark doesn't allow reusing existing names, even if the result df will not contain duplicate columns
             raise _column_exists_error(new_name)
         if (global_config.spark_sql_caseSensitive and new_name in seen) or (
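
Note: the extra validation in map_with_columns_renamed allows a case-only self-rename (e.g. id -> ID with spark.sql.caseSensitive = false) while still rejecting a new name that collides with a different existing column. The rule in isolation, as a hypothetical helper rather than the package's API:

def rename_conflicts(existing_cols, rename_map, case_sensitive=False):
    """Return new names that collide with existing columns, excluding case-only self-renames."""
    conflicts = []
    for old_name, new_name in rename_map.items():
        exists = any(
            c == new_name if case_sensitive else c.lower() == new_name.lower()
            for c in existing_cols
        )
        self_rename = (not case_sensitive) and old_name.lower() == new_name.lower()
        if exists and not self_rename:
            conflicts.append(new_name)
    return conflicts

assert rename_conflicts(["id", "name"], {"id": "ID"}) == []          # case-only rename allowed
assert rename_conflicts(["id", "name"], {"id": "NAME"}) == ["NAME"]  # collides with another column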

@@ -1056,14 +1108,12 @@ def map_group_map(
     snowpark_grouping_expressions: list[snowpark.Column] = []
     typer = ExpressionTyper(input_df)
     group_name_list: list[str] = []
-    qualifiers = []
     for exp in grouping_expressions:
         new_name, snowpark_column = map_single_column_expression(
             exp, input_container.column_map, typer
         )
         snowpark_grouping_expressions.append(snowpark_column.col)
         group_name_list.append(new_name)
-        qualifiers.append(snowpark_column.get_qualifiers())
     if rel.group_map.func.python_udf is None:
         raise ValueError("group_map relation without python udf is not supported")
 

@@ -1105,13 +1155,14 @@ def map_group_map(
     result = input_df.group_by(*snowpark_grouping_expressions).apply_in_pandas(
         callable_func, output_type
     )
-
-
+    # The UDTF `apply_in_pandas` generates a new table whose output schema
+    # can be entirely different from that of the input Snowpark DataFrame.
+    # As a result, the output DataFrame should not use qualifiers based on the input group by columns.
     return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=[field.name for field in output_type],
         snowpark_column_names=result.columns,
-        column_qualifiers=
+        column_qualifiers=None,
         parent_column_name_map=input_container.column_map,
     )
 