snowpark-connect 0.29.0__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client.py +65 -0
- snowflake/snowpark_connect/column_name_handler.py +6 -0
- snowflake/snowpark_connect/config.py +25 -3
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
- snowflake/snowpark_connect/expression/map_extension.py +277 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +253 -59
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/io_utils.py +61 -4
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
- snowflake/snowpark_connect/relation/map_join.py +8 -0
- snowflake/snowpark_connect/relation/map_row_ops.py +129 -17
- snowflake/snowpark_connect/relation/map_show_string.py +14 -6
- snowflake/snowpark_connect/relation/map_sql.py +39 -5
- snowflake/snowpark_connect/relation/map_stats.py +21 -6
- snowflake/snowpark_connect/relation/read/map_read.py +9 -0
- snowflake/snowpark_connect/relation/read/map_read_csv.py +17 -6
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
- snowflake/snowpark_connect/relation/utils.py +19 -2
- snowflake/snowpark_connect/relation/write/map_write.py +44 -29
- snowflake/snowpark_connect/server.py +11 -3
- snowflake/snowpark_connect/type_mapping.py +75 -3
- snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
- snowflake/snowpark_connect/utils/telemetry.py +105 -23
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/METADATA +1 -1
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/RECORD +41 -37
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/top_level.txt +0 -0

@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\xde\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x12\x44\n\x10interval_literal\x18\x03 \x01(\x0b\x32(.snowflake.ext.IntervalLiteralExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_options\"\x9f\x01\n\x19IntervalLiteralExpression\x12\x32\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.Literal\x12\x18\n\x0bstart_field\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tend_field\x18\x03 \x01(\x05H\x01\x88\x01\x01\x42\x0e\n\x0c_start_fieldB\x0c\n\n_end_fieldb\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,13 +24,15 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snowflake_expression_ext_pb
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals['_EXPEXTENSION']._serialized_start=114
-  _globals['_EXPEXTENSION']._serialized_end=
-  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=
-  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=
+  _globals['_EXPEXTENSION']._serialized_end=336
+  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=338
+  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=418
+  _globals['_SUBQUERYEXPRESSION']._serialized_start=421
+  _globals['_SUBQUERYEXPRESSION']._serialized_end=1049
+  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=693
+  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=880
+  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=883
+  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=1027
+  _globals['_INTERVALLITERALEXPRESSION']._serialized_start=1052
+  _globals['_INTERVALLITERALEXPRESSION']._serialized_end=1211
 # @@protoc_insertion_point(module_scope)
@@ -9,12 +9,14 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
 DESCRIPTOR: _descriptor.FileDescriptor

 class ExpExtension(_message.Message):
-    __slots__ = ("named_argument", "subquery_expression")
+    __slots__ = ("named_argument", "subquery_expression", "interval_literal")
     NAMED_ARGUMENT_FIELD_NUMBER: _ClassVar[int]
     SUBQUERY_EXPRESSION_FIELD_NUMBER: _ClassVar[int]
+    INTERVAL_LITERAL_FIELD_NUMBER: _ClassVar[int]
     named_argument: NamedArgumentExpression
     subquery_expression: SubqueryExpression
-    def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ...) -> None: ...
+    interval_literal: IntervalLiteralExpression
+    def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ..., interval_literal: _Optional[_Union[IntervalLiteralExpression, _Mapping]] = ...) -> None: ...

 class NamedArgumentExpression(_message.Message):
     __slots__ = ("key", "value")
@@ -56,3 +58,13 @@ class SubqueryExpression(_message.Message):
     table_arg_options: SubqueryExpression.TableArgOptions
     in_subquery_values: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., subquery_type: _Optional[_Union[SubqueryExpression.SubqueryType, str]] = ..., table_arg_options: _Optional[_Union[SubqueryExpression.TableArgOptions, _Mapping]] = ..., in_subquery_values: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ...) -> None: ...
+
+class IntervalLiteralExpression(_message.Message):
+    __slots__ = ("literal", "start_field", "end_field")
+    LITERAL_FIELD_NUMBER: _ClassVar[int]
+    START_FIELD_FIELD_NUMBER: _ClassVar[int]
+    END_FIELD_FIELD_NUMBER: _ClassVar[int]
+    literal: _expressions_pb2.Expression.Literal
+    start_field: int
+    end_field: int
+    def __init__(self, literal: _Optional[_Union[_expressions_pb2.Expression.Literal, _Mapping]] = ..., start_field: _Optional[int] = ..., end_field: _Optional[int] = ...) -> None: ...
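The regenerated stubs above add the IntervalLiteralExpression message and wire it into the ExpExtension oneof. A minimal, hedged sketch of constructing the new message from Python; the module paths follow the file list at the top of this diff, while the day_time_interval literal field and the DAY=0/SECOND=3 start/end codes are assumptions borrowed from Spark's day-time interval convention, not something this diff confirms:

# Hedged sketch only; field values below are illustrative assumptions.
from pyspark.sql.connect.proto import expressions_pb2
from snowflake.snowpark_connect.proto import snowflake_expression_ext_pb2 as ext

interval = ext.IntervalLiteralExpression(
    literal=expressions_pb2.Expression.Literal(
        day_time_interval=86_400_000_000  # one day, expressed in microseconds (assumed encoding)
    ),
    start_field=0,  # DAY, per Spark's DayTimeIntervalType field codes (assumed)
    end_field=3,    # SECOND
)
wrapper = ext.ExpExtension(interval_literal=interval)  # lands in the message's `op` oneof
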
@@ -4,6 +4,8 @@

 from urllib.parse import urlparse

+from pyspark.errors.exceptions.base import AnalysisException
+
 CLOUD_PREFIX_TO_CLOUD = {
     "abfss": "azure",
     "wasbs": "azure",
@@ -12,10 +14,28 @@ CLOUD_PREFIX_TO_CLOUD = {
 }

 SUPPORTED_COMPRESSION_PER_FORMAT = {
-    "csv": {
-
-
-
+    "csv": {
+        "GZIP",
+        "BZ2",
+        "BROTLI",
+        "ZSTD",
+        "DEFLATE",
+        "RAW_DEFLATE",
+        "NONE",
+        "UNCOMPRESSED",
+    },
+    "json": {
+        "GZIP",
+        "BZ2",
+        "BROTLI",
+        "ZSTD",
+        "DEFLATE",
+        "RAW_DEFLATE",
+        "NONE",
+        "UNCOMPRESSED",
+    },
+    "parquet": {"LZO", "SNAPPY", "NONE", "UNCOMPRESSED"},
+    "text": {"NONE", "UNCOMPRESSED"},
 }

@@ -29,6 +49,43 @@ def is_supported_compression(format: str, compression: str | None) -> bool:
     return compression in supported_compressions_for_format(format)


+def get_compression_for_source_and_options(
+    source: str, options: dict[str, str], from_read: bool = False
+) -> str | None:
+    """
+    Determines the compression type to use for a given data source and options.
+    Args:
+        source (str): The data source format (e.g., "csv", "json", "parquet", "text").
+        options (dict[str, str]): A dictionary of options that may include a "compression" key.
+    Returns:
+        str: The compression type to use (e.g., "GZIP", "SNAPPY", "NONE").
+    Raises:
+        AnalysisException: If the specified compression is not supported for the given source format.
+    """
+    # From read, we don't have a default compression
+    if from_read and "compression" not in options:
+        return None
+
+    # Get compression from options for proper filename generation
+    default_compression = "NONE" if source != "parquet" else "snappy"
+    compression = options.get("compression", default_compression).upper()
+    if compression == "UNCOMPRESSED":
+        compression = "NONE"
+
+    if not is_supported_compression(source, compression):
+        supported_compressions = supported_compressions_for_format(source)
+        raise AnalysisException(
+            f"Compression {compression} is not supported for {source} format. "
+            + (
+                f"Supported compressions: {sorted(supported_compressions)}"
+                if supported_compressions
+                else "None compression supported for this format."
+            )
+        )
+
+    return compression
+
+
 def get_cloud_from_url(
     url: str,
 ):
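Per the docstring above, the helper resolves compression as: the explicit option if present, otherwise a per-format default ("snappy" for parquet, "NONE" for everything else), normalizes "UNCOMPRESSED" to "NONE", and rejects unsupported combinations. A few illustrative calls derived from that logic (not taken from the package's test suite):

from snowflake.snowpark_connect.relation.io_utils import get_compression_for_source_and_options

get_compression_for_source_and_options("csv", {"compression": "gzip"})           # "GZIP"
get_compression_for_source_and_options("parquet", {})                            # "SNAPPY" (default, upper-cased)
get_compression_for_source_and_options("text", {"compression": "uncompressed"})  # "NONE"
get_compression_for_source_and_options("json", {}, from_read=True)               # None (reads get no default)
get_compression_for_source_and_options("text", {"compression": "gzip"})          # raises AnalysisException
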
@@ -1124,10 +1124,15 @@ def map_group_map(
         group_by_df = input_df.group_by(*snowpark_grouping_expressions)
         inner_df = group_by_df._dataframe

-
-
-
-
+        renamed_columns = [f"snowflake_jtf_{column}" for column in input_df.columns]
+        tfc = snowpark_fn.call_table_function(
+            apply_udtf_temp_name, *renamed_columns
+        ).over(partition_by=snowpark_grouping_expressions)
+
+        result = (
+            inner_df.to_df(renamed_columns)
+            .join_table_function(tfc)
+            .drop(*renamed_columns)
         )
     else:
         (
@@ -21,6 +21,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)
 from snowflake.snowpark_connect.utils.context import (
     push_evaluating_join_condition,
     push_sql_scope,
@@ -38,6 +41,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)

+    # Remove any metadata columns(like metada$filename) present in the dataframes.
+    # We cannot support inputfilename for multisources as each dataframe has it's own source.
+    left_container = filter_metadata_columns(left_container)
+    right_container = filter_metadata_columns(right_container)
+
     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
     is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
@@ -1,15 +1,14 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
-
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentException

 import snowflake.snowpark_connect.relation.utils as utils
 from snowflake import snowpark
-from snowflake.snowpark.
+from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
+from snowflake.snowpark.functions import col, expr as snowpark_expr, lit
 from snowflake.snowpark.types import (
     BooleanType,
     ByteType,
@@ -20,8 +19,14 @@ from snowflake.snowpark.types import (
     LongType,
     NullType,
     ShortType,
+    StructField,
+    StructType,
+)
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNameMap,
+    schema_getter,
+    set_schema_getter,
 )
-from snowflake.snowpark_connect.column_name_handler import ColumnNameMap, schema_getter
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
@@ -30,6 +35,9 @@ from snowflake.snowpark_connect.expression.map_expression import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.utils.identifiers import (
+    split_fully_qualified_spark_name,
+)
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -126,11 +134,19 @@ def map_fillna(
     input_df = input_container.dataframe

     if len(rel.fill_na.cols) > 0:
+        if rel.fill_na.cols == ["*"]:
+            # Expand "*" to all columns
+            spark_col_names = input_container.column_map.get_spark_columns()
+        else:
+            spark_col_names = list(rel.fill_na.cols)
+
+        # We don't validate the fully qualified spark name here as fillNa is no-op for structured type colums.
+        # It only works for scalar type columns like float, int, string or bool.
         columns: list[str] = [
             input_container.column_map.get_snowpark_column_name_from_spark_column_name(
-                c
+                split_fully_qualified_spark_name(c)[0]
             )
-            for c in
+            for c in spark_col_names
         ]
         values = [get_literal_field_and_name(v)[0] for v in rel.fill_na.values]
         if len(values) == 1:
@@ -318,23 +334,37 @@ def map_union(
         right_column_map = right_result.column_map
         columns_to_restore: dict[str, tuple[str, str]] = {}

-
+        original_right_schema = right_df.schema
+        right_renamed_fields = []
+        for field in original_right_schema.fields:
             spark_name = (
-                right_column_map.get_spark_column_name_from_snowpark_column_name(
+                right_column_map.get_spark_column_name_from_snowpark_column_name(
+                    field.name
+                )
+            )
+            right_df = right_df.withColumnRenamed(field.name, spark_name)
+            columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+            right_renamed_fields.append(
+                StructField(spark_name, field.datatype, field.nullable)
             )
-
-            columns_to_restore[spark_name.upper()] = (spark_name, column)
+        set_schema_getter(right_df, lambda: StructType(right_renamed_fields))

-
+        original_left_schema = left_df.schema
+        left_renamed_fields = []
+        for field in original_left_schema.fields:
             spark_name = (
-                left_column_map.get_spark_column_name_from_snowpark_column_name(
+                left_column_map.get_spark_column_name_from_snowpark_column_name(
+                    field.name
+                )
+            )
+            left_df = left_df.withColumnRenamed(field.name, spark_name)
+            columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+            left_renamed_fields.append(
+                StructField(spark_name, field.datatype, field.nullable)
             )
-
-            columns_to_restore[spark_name.upper()] = (spark_name, column)
+        set_schema_getter(left_df, lambda: StructType(left_renamed_fields))

-        result = left_df
-            right_df, allow_missing_columns=allow_missing_columns
-        )
+        result = _union_by_name_optimized(left_df, right_df, allow_missing_columns)

         if allow_missing_columns:
             spark_columns = []
@@ -809,3 +839,85 @@ def map_tail(
         alias=input_container.alias,
         cached_schema_getter=lambda: input_df.schema,
     )
+
+
+def _union_by_name_optimized(
+    left_df: snowpark.DataFrame,
+    right_df: snowpark.DataFrame,
+    allow_missing_columns: bool = False,
+) -> snowpark.DataFrame:
+    """
+    This implementation is an optimized version of Snowpark's Dataframe::_union_by_name_internal.
+    The only change is, that it avoids redundant schema queries that occur in the standard Snowpark,
+    by reusing already-fetched/calculated schemas.
+    """
+
+    left_schema = left_df.schema
+    right_schema = right_df.schema
+
+    left_cols = {field.name for field in left_schema.fields}
+    right_cols = {field.name for field in right_schema.fields}
+    right_field_map = {field.name: field for field in right_schema.fields}
+
+    missing_left = right_cols - left_cols
+    missing_right = left_cols - right_cols
+
+    def add_nulls(
+        missing_cols: set[str], to_df: snowpark.DataFrame, from_df: snowpark.DataFrame
+    ) -> snowpark.DataFrame:
+        dt_map = {field.name: field.datatype for field in from_df.schema.fields}
+        result = to_df.select(
+            "*",
+            *[lit(None).cast(dt_map[col]).alias(col) for col in missing_cols],
+        )
+
+        result_fields = []
+        for field in to_df.schema.fields:
+            result_fields.append(
+                StructField(field.name, field.datatype, field.nullable)
+            )
+        for col_name in missing_cols:
+            from_field = next(
+                field for field in from_df.schema.fields if field.name == col_name
+            )
+            result_fields.append(
+                StructField(col_name, from_field.datatype, from_field.nullable)
+            )
+
+        set_schema_getter(result, lambda: StructType(result_fields))
+
+        return result
+
+    if missing_left or missing_right:
+        if allow_missing_columns:
+            left = left_df
+            right = right_df
+            if missing_left:
+                left = add_nulls(missing_left, left, right)
+            if missing_right:
+                right = add_nulls(missing_right, right, left)
+            result = left._union_by_name_internal(right, is_all=True)
+
+            result_fields = []
+            for field in left_schema.fields:
+                result_fields.append(
+                    StructField(field.name, field.datatype, field.nullable)
+                )
+            for col_name in missing_left:
+                right_field = right_field_map[col_name]
+                result_fields.append(
+                    StructField(col_name, right_field.datatype, right_field.nullable)
+                )
+
+            set_schema_getter(result, lambda: StructType(result_fields))
+            return result
+        else:
+            raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
+                missing_left, missing_right
+            )
+
+    result = left_df.unionAllByName(
+        right_df, allow_missing_columns=allow_missing_columns
+    )
+    set_schema_getter(result, lambda: left_df.schema)
+    return result
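Per its docstring, the helper keeps Snowpark's unionByName/allow_missing_columns semantics while skipping the extra schema (describe) queries by registering precomputed schema getters. As a plain-Python reminder of the semantics being preserved, a toy sketch of the null padding that add_nulls performs (illustrative only, no Snowpark involved):

# Missing columns on each side are padded with NULLs before the positional union;
# add_nulls() above does the same with lit(None).cast(<datatype>).alias(<name>).
left_rows = [{"id": 1, "a": "x"}]
right_rows = [{"id": 2, "b": "y"}]

all_cols = list(dict.fromkeys([*left_rows[0], *right_rows[0]]))  # ["id", "a", "b"]

def pad(rows, cols):
    return [{c: r.get(c) for c in cols} for r in rows]

print(pad(left_rows, all_cols) + pad(right_rows, all_cols))
# [{'id': 1, 'a': 'x', 'b': None}, {'id': 2, 'a': None, 'b': 'y'}]
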
@@ -15,6 +15,9 @@ from snowflake.snowpark_connect.column_name_handler import set_schema_getter
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)


 def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
@@ -26,14 +29,17 @@ def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
     Buffer object as a single cell.
     """
     input_df_container: DataFrameContainer = map_relation(rel.show_string.input)
-
-
+    filtered_container = filter_metadata_columns(input_df_container)
+    display_df = filtered_container.dataframe
+    display_spark_columns = filtered_container.column_map.get_spark_columns()
+
+    input_df = _handle_datetype_columns(display_df)

     show_string = input_df._show_string_spark(
         num_rows=rel.show_string.num_rows,
         truncate=rel.show_string.truncate,
         vertical=rel.show_string.vertical,
-        _spark_column_names=
+        _spark_column_names=display_spark_columns,
         _spark_session_tz=global_config.spark_sql_session_timeZone,
     )
     return pandas.DataFrame({"show_string": [show_string]})
@@ -44,14 +50,16 @@ def map_repr_html(rel: relation_proto.Relation) -> pandas.DataFrame:
     Generate the html string representation of the input dataframe.
     """
     input_df_container: DataFrameContainer = map_relation(rel.html_string.input)
-
+
+    filtered_container = filter_metadata_columns(input_df_container)
+    input_df = filtered_container.dataframe

     input_panda = input_df.toPandas()
     input_panda.rename(
         columns={
             analyzer_utils.unquote_if_quoted(
-
-            ):
+                filtered_container.column_map.get_snowpark_columns()[i]
+            ): filtered_container.column_map.get_spark_columns()[i]
             for i in range(len(input_panda.columns))
         },
         inplace=True,
@@ -7,6 +7,7 @@ from collections.abc import MutableMapping, MutableSequence
 from contextlib import contextmanager, suppress
 from contextvars import ContextVar
 from functools import reduce
+from typing import Tuple

 import jpype
 import pandas
@@ -31,6 +32,10 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
 from snowflake.snowpark.functions import when_matched, when_not_matched
+from snowflake.snowpark_connect.client import (
+    SQL_PASS_THROUGH_MARKER,
+    calculate_checksum,
+)
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
     check_table_supports_operation,
@@ -397,7 +402,7 @@ def map_sql_to_pandas_df(
     returns a tuple of None for SELECT queries to enable lazy evaluation
     """

-    snowpark_connect_sql_passthrough =
+    snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)

     if not snowpark_connect_sql_passthrough:
         logical_plan = sql_parser().parsePlan(sql_string)
@@ -1047,7 +1052,7 @@ def map_sql_to_pandas_df(
             raise AnalysisException(
                 f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
                 f"This table was created as a v1 table with a data source that doesn't support column renaming. "
-                f"To enable this operation, set 'enable_snowflake_extension_behavior' to 'true'."
+                f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
             )

             column_obj = logical_plan.column()
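This hunk, and the map_approx_quantile one further down, switch to the fully qualified configuration key. Presumably the flag is toggled like any other Spark conf on the client; a hypothetical usage sketch, where `spark` stands for an already-established Spark Connect session backed by snowpark-connect:

# Hypothetical client-side usage; `spark` is an assumed active SparkSession.
spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")
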
@@ -1282,6 +1287,14 @@ def map_sql_to_pandas_df(
             return pandas.DataFrame({"": [""]}), ""

             rows = session.sql(snowflake_sql).collect()
+        case "RefreshTable":
+            table_name_unquoted = ".".join(
+                str(part)
+                for part in as_java_list(logical_plan.child().multipartIdentifier())
+            )
+            SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
+
+            return pandas.DataFrame({"": [""]}), ""
         case _:
             execute_logical_plan(logical_plan)
             return None, None
@@ -1302,6 +1315,27 @@ def get_sql_passthrough() -> bool:
     return get_boolean_session_config_param("snowpark.connect.sql.passthrough")


+def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
+    """
+    Checks if :param sql_stmt: should be executed as SQL pass-through. SQL pass-through can be detected in 1 of 2 ways:
+    1) Either Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated)
+    2) If :param sql_stmt: is created through SnowflakeSession and has correct marker + checksum
+    """
+    if get_sql_passthrough():
+        # legacy style pass-through, sql_stmt should be a whole, valid SQL statement
+        return True, sql_stmt
+
+    # check for new style, SnowflakeSession based SQL pass-through
+    sql_parts = sql_stmt.split(" ", 2)
+    if len(sql_parts) == 3:
+        marker, checksum, sql = sql_parts
+        if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
+            return True, sql
+
+    # Not a SQL pass-through
+    return False, sql_stmt
+
+
 def change_default_to_public(name: str) -> str:
     """
     Change the namespace to PUBLIC when given name is DEFAULT
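The docstring above describes the two detection paths; the interesting one is the new marker-plus-checksum prefix. A self-contained sketch of that branch, using a made-up marker value and a made-up checksum function in place of the real SQL_PASS_THROUGH_MARKER and calculate_checksum exported by snowflake/snowpark_connect/client.py (that module's contents are not shown in this diff):

import hashlib

SQL_PASS_THROUGH_MARKER = "SNOWPARK_CONNECT_PASSTHROUGH"  # hypothetical value
def calculate_checksum(sql: str) -> str:                  # hypothetical implementation
    return hashlib.sha256(sql.encode("utf-8")).hexdigest()

def is_valid_passthrough_sql_sketch(sql_stmt: str) -> tuple[bool, str]:
    # Mirrors the new-style branch above: "<marker> <checksum> <sql>" runs verbatim.
    parts = sql_stmt.split(" ", 2)
    if len(parts) == 3:
        marker, checksum, sql = parts
        if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
            return True, sql
    return False, sql_stmt

tagged = f"{SQL_PASS_THROUGH_MARKER} {calculate_checksum('SELECT 1')} SELECT 1"
assert is_valid_passthrough_sql_sketch(tagged) == (True, "SELECT 1")
assert is_valid_passthrough_sql_sketch("SELECT 1") == (False, "SELECT 1")
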
@@ -1397,10 +1431,10 @@ def map_sql(
     In passthough mode as True, SAS calls session.sql() and not calling Spark Parser.
     This is to mitigate any issue not covered by spark logical plan to protobuf conversion.
     """
-    snowpark_connect_sql_passthrough =
+    snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)

     if not snowpark_connect_sql_passthrough:
-        logical_plan = sql_parser().parseQuery(
+        logical_plan = sql_parser().parseQuery(sql_stmt)

         parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
         set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1408,7 +1442,7 @@ def map_sql(
         return execute_logical_plan(logical_plan)
     else:
         session = snowpark.Session.get_active_session()
-        sql_df = session.sql(
+        sql_df = session.sql(sql_stmt)
         columns = sql_df.columns
         return DataFrameContainer.create_with_column_mapping(
             dataframe=sql_df,
@@ -81,7 +81,7 @@ def map_approx_quantile(
     input_df = input_container.dataframe

     snowflake_compatible = get_boolean_session_config_param(
-        "enable_snowflake_extension_behavior"
+        "snowpark.connect.enable_snowflake_extension_behavior"
     )

     if not snowflake_compatible:
@@ -309,9 +309,28 @@ def map_freq_items(rel: relation_proto.Relation) -> DataFrameContainer:
     cols = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
         list(rel.freq_items.cols)
     )
+
+    # handle empty DataFrame case
+    row_count = input_df.count()
+
+    for sp_col_name in cols:
+        spark_col_names.append(
+            f"{input_container.column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
+        )
+
+    if row_count == 0:
+        # If DataFrame is empty, return empty arrays for each column
+        empty_values = [[] for _ in cols]
+        approx_top_k_df = session.createDataFrame([empty_values], spark_col_names)
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=approx_top_k_df,
+            spark_column_names=spark_col_names,
+            snowpark_column_names=spark_col_names,
+        )
+
     approx_top_k_df = input_df.select(
         *[
-            fn.function("approx_top_k")(fn.col(col), round(
+            fn.function("approx_top_k")(fn.col(col), round(row_count / support))
             for col in cols
         ]
     )
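With the new early-return above, an empty input never reaches approx_top_k; the client should now get back a single row of empty arrays. A hedged, client-side illustration using the plain PySpark API (the session here is a local stand-in, not a snowpark-connect session, and the expected output is inferred from the diff rather than taken from the package's tests):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # stand-in for an assumed snowpark-connect session
empty = spark.createDataFrame([], "a INT, b STRING")
print(empty.freqItems(["a", "b"]).collect())
# Expected shape per the diff: one row of empty arrays, e.g.
# [Row(a_freqItems=[], b_freqItems=[])]
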
@@ -330,10 +349,6 @@ def map_freq_items(rel: relation_proto.Relation) -> DataFrameContainer:
         for value in approx_top_k_values
     ]

-    for sp_col_name in cols:
-        spark_col_names.append(
-            f"{input_container.column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
-        )
     approx_top_k_df = session.createDataFrame([filtered_values], spark_col_names)

     return DataFrameContainer.create_with_column_mapping(
@@ -17,6 +17,7 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.io_utils import (
     convert_file_prefix_path,
+    get_compression_for_source_and_options,
     is_cloud_path,
 )
 from snowflake.snowpark_connect.relation.read.map_read_table import map_read_table
@@ -237,6 +238,14 @@ def _read_file(
     )
     upload_files_if_needed(paths, clean_source_paths, session, read_format)
     paths = [_quote_stage_path(path) for path in paths]
+
+    if read_format in ("csv", "text", "json", "parquet"):
+        compression = get_compression_for_source_and_options(
+            read_format, options, from_read=True
+        )
+        if compression is not None:
+            options["compression"] = compression
+
     match read_format:
         case "csv":
             from snowflake.snowpark_connect.relation.read.map_read_csv import (
|