snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client.py +65 -0
- snowflake/snowpark_connect/column_name_handler.py +6 -0
- snowflake/snowpark_connect/config.py +33 -5
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
- snowflake/snowpark_connect/expression/map_extension.py +277 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_join.py +8 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
- snowflake/snowpark_connect/relation/map_show_string.py +14 -6
- snowflake/snowpark_connect/relation/map_sql.py +39 -5
- snowflake/snowpark_connect/relation/map_stats.py +1 -1
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
- snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +67 -4
- snowflake/snowpark_connect/server.py +29 -16
- snowflake/snowpark_connect/type_mapping.py +75 -3
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/telemetry.py +30 -5
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\xde\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x12\x44\n\x10interval_literal\x18\x03 \x01(\x0b\x32(.snowflake.ext.IntervalLiteralExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_options\"\x9f\x01\n\x19IntervalLiteralExpression\x12\x32\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.Literal\x12\x18\n\x0bstart_field\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tend_field\x18\x03 \x01(\x05H\x01\x88\x01\x01\x42\x0e\n\x0c_start_fieldB\x0c\n\n_end_fieldb\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,13 +24,15 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snowflake_expression_ext_pb
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals['_EXPEXTENSION']._serialized_start=114
-  _globals['_EXPEXTENSION']._serialized_end=
-  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=
-  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=
-  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=
-  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=
+  _globals['_EXPEXTENSION']._serialized_end=336
+  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=338
+  _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=418
+  _globals['_SUBQUERYEXPRESSION']._serialized_start=421
+  _globals['_SUBQUERYEXPRESSION']._serialized_end=1049
+  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=693
+  _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=880
+  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=883
+  _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=1027
+  _globals['_INTERVALLITERALEXPRESSION']._serialized_start=1052
+  _globals['_INTERVALLITERALEXPRESSION']._serialized_end=1211
 # @@protoc_insertion_point(module_scope)
@@ -9,12 +9,14 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
 DESCRIPTOR: _descriptor.FileDescriptor
 
 class ExpExtension(_message.Message):
-    __slots__ = ("named_argument", "subquery_expression")
+    __slots__ = ("named_argument", "subquery_expression", "interval_literal")
     NAMED_ARGUMENT_FIELD_NUMBER: _ClassVar[int]
     SUBQUERY_EXPRESSION_FIELD_NUMBER: _ClassVar[int]
+    INTERVAL_LITERAL_FIELD_NUMBER: _ClassVar[int]
     named_argument: NamedArgumentExpression
     subquery_expression: SubqueryExpression
-    def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ...) -> None: ...
+    interval_literal: IntervalLiteralExpression
+    def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ..., interval_literal: _Optional[_Union[IntervalLiteralExpression, _Mapping]] = ...) -> None: ...
 
 class NamedArgumentExpression(_message.Message):
     __slots__ = ("key", "value")
@@ -56,3 +58,13 @@ class SubqueryExpression(_message.Message):
     table_arg_options: SubqueryExpression.TableArgOptions
     in_subquery_values: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
     def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., subquery_type: _Optional[_Union[SubqueryExpression.SubqueryType, str]] = ..., table_arg_options: _Optional[_Union[SubqueryExpression.TableArgOptions, _Mapping]] = ..., in_subquery_values: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ...) -> None: ...
+
+class IntervalLiteralExpression(_message.Message):
+    __slots__ = ("literal", "start_field", "end_field")
+    LITERAL_FIELD_NUMBER: _ClassVar[int]
+    START_FIELD_FIELD_NUMBER: _ClassVar[int]
+    END_FIELD_FIELD_NUMBER: _ClassVar[int]
+    literal: _expressions_pb2.Expression.Literal
+    start_field: int
+    end_field: int
+    def __init__(self, literal: _Optional[_Union[_expressions_pb2.Expression.Literal, _Mapping]] = ..., start_field: _Optional[int] = ..., end_field: _Optional[int] = ...) -> None: ...
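The stubs above define the new IntervalLiteralExpression message and the third branch of ExpExtension. A minimal sketch of constructing these messages from the generated bindings; the interval value and the start/end field indices are illustrative assumptions, not values taken from the package:

from pyspark.sql.connect.proto import expressions_pb2
from snowflake.snowpark_connect.proto import snowflake_expression_ext_pb2 as ext_pb2

# One day expressed as a Spark day-time interval literal (microseconds); the
# start_field/end_field indices (assumed DAY=0 .. SECOND=3) are illustrative only.
interval = ext_pb2.IntervalLiteralExpression(
    literal=expressions_pb2.Expression.Literal(day_time_interval=86_400_000_000),
    start_field=0,
    end_field=3,
)
extension = ext_pb2.ExpExtension(interval_literal=interval)
print(extension)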
@@ -7,8 +7,27 @@ from urllib.parse import urlparse
 CLOUD_PREFIX_TO_CLOUD = {
     "abfss": "azure",
     "wasbs": "azure",
+    "gcs": "gcp",
+    "gs": "gcp",
 }
 
+SUPPORTED_COMPRESSION_PER_FORMAT = {
+    "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+    "text": {"NONE"},
+}
+
+
+def supported_compressions_for_format(format: str) -> set[str]:
+    return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+def is_supported_compression(format: str, compression: str | None) -> bool:
+    if compression is None:
+        return True
+    return compression in supported_compressions_for_format(format)
+
 
 def get_cloud_from_url(
     url: str,
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
         or path.startswith("azure://")
         or path.startswith("abfss://")
         or path.startswith("wasbs://")  # Azure
-        or path.startswith("gcs://")
+        or path.startswith("gcs://")
+        or path.startswith("gs://")  # GCP
     )
 
 
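The compression table added above gates which codecs each reader format accepts. A standalone sketch that mirrors the two helpers (reimplemented here rather than imported, since this hunk does not show which io_utils module they land in); the example calls are illustrative:

# Mirrors supported_compressions_for_format / is_supported_compression from the hunk above.
SUPPORTED_COMPRESSION_PER_FORMAT = {
    "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
    "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
    "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
    "text": {"NONE"},
}

def is_supported_compression(format: str, compression: str | None) -> bool:
    # None means "not specified", which is always accepted
    if compression is None:
        return True
    return compression in SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())

print(is_supported_compression("csv", "GZIP"))      # True
print(is_supported_compression("parquet", "GZIP"))  # False: parquet only allows AUTO/LZO/SNAPPY/NONE
print(is_supported_compression("text", None))       # True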
@@ -1124,10 +1124,15 @@ def map_group_map(
         group_by_df = input_df.group_by(*snowpark_grouping_expressions)
         inner_df = group_by_df._dataframe
 
-
-
-
-
+        renamed_columns = [f"snowflake_jtf_{column}" for column in input_df.columns]
+        tfc = snowpark_fn.call_table_function(
+            apply_udtf_temp_name, *renamed_columns
+        ).over(partition_by=snowpark_grouping_expressions)
+
+        result = (
+            inner_df.to_df(renamed_columns)
+            .join_table_function(tfc)
+            .drop(*renamed_columns)
         )
     else:
         (
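This hunk, like the _call_udtf change further down, renames the input columns with a snowflake_jtf_ prefix before joining the table function, so the UDTF's output columns cannot collide with same-named inputs. A rough sketch of that rename, join_table_function, drop pattern; it assumes an open Snowpark session, and the UDTF and column names are invented for illustration:

# Illustrative only: prefix-rename -> join_table_function -> drop.
from snowflake.snowpark import Session
from snowflake.snowpark.functions import call_table_function, col
from snowflake.snowpark.types import IntegerType, StructField, StructType

session = Session.builder.getOrCreate()  # assumes connection parameters are already configured

class DoubleIt:
    def process(self, n: int):
        yield (n * 2,)

double_it = session.udtf.register(
    DoubleIt,
    output_schema=StructType([StructField("N", IntegerType())]),  # output column also named N
    input_types=[IntegerType()],
)

df = session.create_dataframe([[1], [2]], schema=["N"])       # input column named N too
renamed = [f"snowflake_jtf_{c}" for c in df.columns]          # avoid the N vs N clash
tfc = call_table_function(double_it.name, col(renamed[0]))
result = df.to_df(renamed).join_table_function(tfc).drop(*renamed)
result.show()  # only the UDTF output column N remains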
@@ -345,7 +345,7 @@ def map_aggregate(
         return new_names[0], snowpark_column
 
     raw_groupings: list[tuple[str, TypedColumn]] = []
-    raw_aggregations: list[tuple[str, TypedColumn]] = []
+    raw_aggregations: list[tuple[str, TypedColumn, list[str]]] = []
 
     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -375,10 +375,21 @@ def map_aggregate(
     # Note: We don't clear the map here to preserve any parent context aliases
     from snowflake.snowpark_connect.utils.context import register_lca_alias
 
+    # If it's an unresolved attribute when its in aggregate.aggregate_expressions, we know it came from the parent map straight away
+    # in this case, we should see if the parent map has a qualifier for it and propagate that here, in case the order by references it in
+    # a qualified way later.
     agg_count = get_sql_aggregate_function_count()
     for exp in aggregate.aggregate_expressions:
         col = _map_column(exp)
-
+        if exp.WhichOneof("expr_type") == "unresolved_attribute":
+            spark_name = col[0]
+            qualifiers = input_container.column_map.get_qualifier_for_spark_column(
+                spark_name
+            )
+        else:
+            qualifiers = []
+
+        raw_aggregations.append((col[0], col[1], qualifiers))
 
         # If this is an alias, register it in the LCA map for subsequent expressions
         if (
@@ -409,18 +420,20 @@ def map_aggregate(
     spark_columns: list[str] = []
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
+    all_qualifiers: list[list[str]] = []
 
     # Use grouping columns directly without aliases
     groupings = [col.col for _, col in raw_groupings]
 
     # Create aliases only for aggregation columns
     aggregations = []
-    for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+    for i, (spark_name, snowpark_column, qualifiers) in enumerate(raw_aggregations):
         alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]
 
         spark_columns.append(spark_name)
         snowpark_columns.append(alias)
         snowpark_column_types.append(snowpark_column.typ)
+        all_qualifiers.append(qualifiers)
 
         aggregations.append(snowpark_column.col.alias(alias))
 
@@ -483,6 +496,7 @@ def map_aggregate(
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
+        column_qualifiers=all_qualifiers,
     ).column_map
 
     # Create hybrid column map that can resolve both input and aggregate contexts
@@ -494,7 +508,9 @@ def map_aggregate(
         aggregate_expressions=list(aggregate.aggregate_expressions),
         grouping_expressions=list(aggregate.grouping_expressions),
         spark_columns=spark_columns,
-        raw_aggregations=
+        raw_aggregations=[
+            (spark_name, col) for spark_name, col, _ in raw_aggregations
+        ],
     )
 
     # Map the HAVING condition using hybrid resolution
@@ -515,4 +531,5 @@ def map_aggregate(
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
         parent_column_name_map=input_df._column_map,
+        column_qualifiers=all_qualifiers,
     )
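Per the comment in the hunk, the qualifiers are carried along so a later ORDER BY can still reference a grouping column through its alias qualifier. A hypothetical PySpark query of that shape (assumes an active Spark session bound to `spark`; names and data are invented):

# Illustrative query shape only; not taken from the package's tests.
from pyspark.sql import functions as F

orders = spark.createDataFrame(
    [(1, 10.0), (1, 5.0), (2, 7.5)], ["customer_id", "amount"]
).alias("o")

result = (
    orders.groupBy("o.customer_id")
    .agg(F.sum("amount").alias("total"))
    .orderBy("o.customer_id")  # qualified reference resolved after the aggregate
)
result.show()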
@@ -21,6 +21,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)
 from snowflake.snowpark_connect.utils.context import (
     push_evaluating_join_condition,
     push_sql_scope,
@@ -38,6 +41,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
     left_container: DataFrameContainer = map_relation(rel.join.left)
     right_container: DataFrameContainer = map_relation(rel.join.right)
 
+    # Remove any metadata columns(like metada$filename) present in the dataframes.
+    # We cannot support inputfilename for multisources as each dataframe has it's own source.
+    left_container = filter_metadata_columns(left_container)
+    right_container = filter_metadata_columns(right_container)
+
     left_input: snowpark.DataFrame = left_container.dataframe
     right_input: snowpark.DataFrame = right_container.dataframe
     is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
-from snowflake.snowpark_connect.utils.context import map_partitions_depth
 from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
     create_pandas_udtf,
     create_pandas_udtf_with_arrow,
@@ -53,18 +52,18 @@ def _call_udtf(
         ).cast("int"),
     )
 
-    udtf_columns = input_df.columns + [
+    udtf_columns = [f"snowflake_jtf_{column}" for column in input_df.columns] + [
+        "_DUMMY_PARTITION_KEY"
+    ]
 
     tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
         partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
    )
 
-    #
-
-
-
-    else:
-        result_df_with_dummy = input_df_with_dummy.select(tfc)
+    # Overwrite the input_df columns to prevent name conflicts with UDTF output columns
+    result_df_with_dummy = input_df_with_dummy.to_df(udtf_columns).join_table_function(
+        tfc
+    )
 
     output_cols = [field.name for field in return_type.fields]
 
@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.context import (
     get_plan_id_map,
     get_session_id,
     not_resolving_fun_args,
-    push_map_partitions,
     push_operation_scope,
     set_is_aggregate_function,
     set_plan_id_map,
@@ -185,8 +184,7 @@ def map_relation(
             )
             return cached_df
         case "map_partitions":
-
-                result = map_map_partitions.map_map_partitions(rel)
+            result = map_map_partitions.map_map_partitions(rel)
         case "offset":
             result = map_row_ops.map_offset(rel)
         case "project":
@@ -1,15 +1,14 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
-
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentException
 
 import snowflake.snowpark_connect.relation.utils as utils
 from snowflake import snowpark
-from snowflake.snowpark.
+from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
+from snowflake.snowpark.functions import col, expr as snowpark_expr, lit
 from snowflake.snowpark.types import (
     BooleanType,
     ByteType,
@@ -20,8 +19,14 @@ from snowflake.snowpark.types import (
     LongType,
     NullType,
     ShortType,
+    StructField,
+    StructType,
+)
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNameMap,
+    schema_getter,
+    set_schema_getter,
 )
-from snowflake.snowpark_connect.column_name_handler import ColumnNameMap, schema_getter
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
@@ -318,23 +323,37 @@ def map_union(
     right_column_map = right_result.column_map
     columns_to_restore: dict[str, tuple[str, str]] = {}
 
-
+    original_right_schema = right_df.schema
+    right_renamed_fields = []
+    for field in original_right_schema.fields:
         spark_name = (
-            right_column_map.get_spark_column_name_from_snowpark_column_name(
+            right_column_map.get_spark_column_name_from_snowpark_column_name(
+                field.name
+            )
         )
-        right_df = right_df.withColumnRenamed(
-        columns_to_restore[spark_name.upper()] = (spark_name,
+        right_df = right_df.withColumnRenamed(field.name, spark_name)
+        columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+        right_renamed_fields.append(
+            StructField(spark_name, field.datatype, field.nullable)
+        )
+    set_schema_getter(right_df, lambda: StructType(right_renamed_fields))
 
-
+    original_left_schema = left_df.schema
+    left_renamed_fields = []
+    for field in original_left_schema.fields:
         spark_name = (
-            left_column_map.get_spark_column_name_from_snowpark_column_name(
+            left_column_map.get_spark_column_name_from_snowpark_column_name(
+                field.name
+            )
        )
-        left_df = left_df.withColumnRenamed(
-        columns_to_restore[spark_name.upper()] = (spark_name,
+        left_df = left_df.withColumnRenamed(field.name, spark_name)
+        columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+        left_renamed_fields.append(
+            StructField(spark_name, field.datatype, field.nullable)
+        )
+    set_schema_getter(left_df, lambda: StructType(left_renamed_fields))
 
-    result = left_df
-        right_df, allow_missing_columns=allow_missing_columns
-    )
+    result = _union_by_name_optimized(left_df, right_df, allow_missing_columns)
 
     if allow_missing_columns:
         spark_columns = []
@@ -809,3 +828,85 @@ def map_tail(
         alias=input_container.alias,
         cached_schema_getter=lambda: input_df.schema,
     )
+
+
+def _union_by_name_optimized(
+    left_df: snowpark.DataFrame,
+    right_df: snowpark.DataFrame,
+    allow_missing_columns: bool = False,
+) -> snowpark.DataFrame:
+    """
+    This implementation is an optimized version of Snowpark's Dataframe::_union_by_name_internal.
+    The only change is, that it avoids redundant schema queries that occur in the standard Snowpark,
+    by reusing already-fetched/calculated schemas.
+    """
+
+    left_schema = left_df.schema
+    right_schema = right_df.schema
+
+    left_cols = {field.name for field in left_schema.fields}
+    right_cols = {field.name for field in right_schema.fields}
+    right_field_map = {field.name: field for field in right_schema.fields}
+
+    missing_left = right_cols - left_cols
+    missing_right = left_cols - right_cols
+
+    def add_nulls(
+        missing_cols: set[str], to_df: snowpark.DataFrame, from_df: snowpark.DataFrame
+    ) -> snowpark.DataFrame:
+        dt_map = {field.name: field.datatype for field in from_df.schema.fields}
+        result = to_df.select(
+            "*",
+            *[lit(None).cast(dt_map[col]).alias(col) for col in missing_cols],
+        )
+
+        result_fields = []
+        for field in to_df.schema.fields:
+            result_fields.append(
+                StructField(field.name, field.datatype, field.nullable)
+            )
+        for col_name in missing_cols:
+            from_field = next(
+                field for field in from_df.schema.fields if field.name == col_name
+            )
+            result_fields.append(
+                StructField(col_name, from_field.datatype, from_field.nullable)
+            )
+
+        set_schema_getter(result, lambda: StructType(result_fields))
+
+        return result
+
+    if missing_left or missing_right:
+        if allow_missing_columns:
+            left = left_df
+            right = right_df
+            if missing_left:
+                left = add_nulls(missing_left, left, right)
+            if missing_right:
+                right = add_nulls(missing_right, right, left)
+            result = left._union_by_name_internal(right, is_all=True)
+
+            result_fields = []
+            for field in left_schema.fields:
+                result_fields.append(
+                    StructField(field.name, field.datatype, field.nullable)
+                )
+            for col_name in missing_left:
+                right_field = right_field_map[col_name]
+                result_fields.append(
+                    StructField(col_name, right_field.datatype, right_field.nullable)
+                )
+
+            set_schema_getter(result, lambda: StructType(result_fields))
+            return result
+        else:
+            raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
+                missing_left, missing_right
+            )
+
+    result = left_df.unionAllByName(
+        right_df, allow_missing_columns=allow_missing_columns
+    )
+    set_schema_getter(result, lambda: left_df.schema)
+    return result
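map_union now routes through _union_by_name_optimized, which reuses already-fetched schemas instead of issuing extra describe queries. The client-side call this path appears to serve is a by-name union with missing columns allowed; a minimal sketch, assuming an active Spark session bound to `spark` and made-up data:

# The kind of call that exercises the optimized union-by-name path above.
left = spark.createDataFrame([(1, "x")], ["id", "left_only"])
right = spark.createDataFrame([(2, 3.5)], ["id", "right_only"])

merged = left.unionByName(right, allowMissingColumns=True)
merged.show()  # columns missing on either side come back as NULLs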
@@ -15,6 +15,9 @@ from snowflake.snowpark_connect.column_name_handler import set_schema_getter
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)
 
 
 def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
@@ -26,14 +29,17 @@ def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
     Buffer object as a single cell.
     """
     input_df_container: DataFrameContainer = map_relation(rel.show_string.input)
-
-
+    filtered_container = filter_metadata_columns(input_df_container)
+    display_df = filtered_container.dataframe
+    display_spark_columns = filtered_container.column_map.get_spark_columns()
+
+    input_df = _handle_datetype_columns(display_df)
 
     show_string = input_df._show_string_spark(
         num_rows=rel.show_string.num_rows,
         truncate=rel.show_string.truncate,
         vertical=rel.show_string.vertical,
-        _spark_column_names=
+        _spark_column_names=display_spark_columns,
         _spark_session_tz=global_config.spark_sql_session_timeZone,
     )
     return pandas.DataFrame({"show_string": [show_string]})
@@ -44,14 +50,16 @@ def map_repr_html(rel: relation_proto.Relation) -> pandas.DataFrame:
     Generate the html string representation of the input dataframe.
     """
     input_df_container: DataFrameContainer = map_relation(rel.html_string.input)
-
+
+    filtered_container = filter_metadata_columns(input_df_container)
+    input_df = filtered_container.dataframe
 
     input_panda = input_df.toPandas()
     input_panda.rename(
         columns={
             analyzer_utils.unquote_if_quoted(
-
-            ):
+                filtered_container.column_map.get_snowpark_columns()[i]
+            ): filtered_container.column_map.get_spark_columns()[i]
             for i in range(len(input_panda.columns))
         },
         inplace=True,
@@ -7,6 +7,7 @@ from collections.abc import MutableMapping, MutableSequence
 from contextlib import contextmanager, suppress
 from contextvars import ContextVar
 from functools import reduce
+from typing import Tuple
 
 import jpype
 import pandas
@@ -31,6 +32,10 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
 from snowflake.snowpark.functions import when_matched, when_not_matched
+from snowflake.snowpark_connect.client import (
+    SQL_PASS_THROUGH_MARKER,
+    calculate_checksum,
+)
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
     check_table_supports_operation,
@@ -397,7 +402,7 @@ def map_sql_to_pandas_df(
     returns a tuple of None for SELECT queries to enable lazy evaluation
     """
 
-    snowpark_connect_sql_passthrough =
+    snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)
 
     if not snowpark_connect_sql_passthrough:
         logical_plan = sql_parser().parsePlan(sql_string)
@@ -1047,7 +1052,7 @@ def map_sql_to_pandas_df(
                 raise AnalysisException(
                     f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
                     f"This table was created as a v1 table with a data source that doesn't support column renaming. "
-                    f"To enable this operation, set 'enable_snowflake_extension_behavior' to 'true'."
+                    f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
                 )
 
             column_obj = logical_plan.column()
@@ -1282,6 +1287,14 @@ def map_sql_to_pandas_df(
                 return pandas.DataFrame({"": [""]}), ""
 
                 rows = session.sql(snowflake_sql).collect()
+            case "RefreshTable":
+                table_name_unquoted = ".".join(
+                    str(part)
+                    for part in as_java_list(logical_plan.child().multipartIdentifier())
+                )
+                SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
+
+                return pandas.DataFrame({"": [""]}), ""
             case _:
                 execute_logical_plan(logical_plan)
                 return None, None
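The new RefreshTable branch maps Spark's REFRESH TABLE command onto the Snowflake catalog. A hypothetical client-side invocation (the table name is invented and an active session bound to `spark` is assumed):

# Either the SQL command or the catalog API ends up in the branch above.
spark.sql("REFRESH TABLE my_db.my_schema.my_table")
spark.catalog.refreshTable("my_db.my_schema.my_table")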
@@ -1302,6 +1315,27 @@ def get_sql_passthrough() -> bool:
     return get_boolean_session_config_param("snowpark.connect.sql.passthrough")
 
 
+def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
+    """
+    Checks if :param sql_stmt: should be executed as SQL pass-through. SQL pass-through can be detected in 1 of 2 ways:
+    1) Either Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated)
+    2) If :param sql_stmt: is created through SnowflakeSession and has correct marker + checksum
+    """
+    if get_sql_passthrough():
+        # legacy style pass-through, sql_stmt should be a whole, valid SQL statement
+        return True, sql_stmt
+
+    # check for new style, SnowflakeSession based SQL pass-through
+    sql_parts = sql_stmt.split(" ", 2)
+    if len(sql_parts) == 3:
+        marker, checksum, sql = sql_parts
+        if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
+            return True, sql
+
+    # Not a SQL pass-through
+    return False, sql_stmt
+
+
 def change_default_to_public(name: str) -> str:
     """
     Change the namespace to PUBLIC when given name is DEFAULT
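is_valid_passthrough_sql accepts either the legacy session flag or a statement prefixed with a marker and a checksum of the remaining SQL. A hedged sketch of the two detection paths from the client side; the warehouse statement is invented, and the marker and checksum values come from client.py's SQL_PASS_THROUGH_MARKER and calculate_checksum, which are not reproduced here:

# Legacy pass-through (detection path 1 above): set the session flag shown in
# get_sql_passthrough(), then send Snowflake SQL as-is. Names are hypothetical.
spark.conf.set("snowpark.connect.sql.passthrough", "true")
spark.sql("ALTER WAREHOUSE my_wh SET WAREHOUSE_SIZE = 'SMALL'")

# New-style pass-through (detection path 2): the statement arrives as
# "<marker> <checksum> <sql>"; the marker and checksum are produced by
# SnowflakeSession via client.py and are deliberately not reproduced here.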
@@ -1397,10 +1431,10 @@ def map_sql(
     In passthough mode as True, SAS calls session.sql() and not calling Spark Parser.
     This is to mitigate any issue not covered by spark logical plan to protobuf conversion.
     """
-    snowpark_connect_sql_passthrough =
+    snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)
 
     if not snowpark_connect_sql_passthrough:
-        logical_plan = sql_parser().parseQuery(
+        logical_plan = sql_parser().parseQuery(sql_stmt)
 
         parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
         set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1408,7 +1442,7 @@ def map_sql(
         return execute_logical_plan(logical_plan)
     else:
         session = snowpark.Session.get_active_session()
-        sql_df = session.sql(
+        sql_df = session.sql(sql_stmt)
         columns = sql_df.columns
         return DataFrameContainer.create_with_column_mapping(
             dataframe=sql_df,
@@ -81,7 +81,7 @@ def map_approx_quantile(
     input_df = input_container.dataframe
 
     snowflake_compatible = get_boolean_session_config_param(
-        "enable_snowflake_extension_behavior"
+        "snowpark.connect.enable_snowflake_extension_behavior"
    )
 
     if not snowflake_compatible: