snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -14
- snowflake/snowpark_connect/error/error_utils.py +32 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +9 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +8 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
- snowflake/snowpark_connect/relation/map_extension.py +58 -24
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
- snowflake/snowpark_connect/relation/map_sql.py +40 -196
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +7 -6
- snowflake/snowpark_connect/relation/utils.py +170 -1
- snowflake/snowpark_connect/relation/write/map_write.py +306 -87
- snowflake/snowpark_connect/server.py +34 -5
- snowflake/snowpark_connect/type_mapping.py +6 -2
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
```diff
@@ -374,23 +374,31 @@ def map_aggregate(
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
 
-
-
-
-
+    # Use grouping columns directly without aliases
+    groupings = [col.col for _, col in raw_groupings]
+
+    # Create aliases only for aggregation columns
+    aggregations = []
+    for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+        alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]
 
         spark_columns.append(spark_name)
         snowpark_columns.append(alias)
         snowpark_column_types.append(snowpark_column.typ)
 
-
-
-    groupings = [_add_column(name, col) for name, col in raw_groupings]
-    aggregations = [_add_column(name, col) for name, col in raw_aggregations]
+        aggregations.append(snowpark_column.col.alias(alias))
 
     match aggregate.group_type:
         case snowflake_proto.Aggregate.GROUP_TYPE_GROUPBY:
-
+            if groupings:
+                # Normal GROUP BY with explicit grouping columns
+                result = input_df.group_by(groupings)
+            else:
+                # No explicit GROUP BY - this is an aggregate over the entire table
+                # Use a dummy constant that will be excluded from the final result
+                result = input_df.with_column(
+                    "__dummy_group__", snowpark_fn.lit(1)
+                ).group_by("__dummy_group__")
         case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
             result = input_df.rollup(groupings)
         case snowflake_proto.Aggregate.GROUP_TYPE_CUBE:
@@ -410,28 +418,54 @@ def map_aggregate(
                 f"Unsupported GROUP BY type: {other}"
             )
 
-    result = result.agg(*aggregations)
+    result = result.agg(*aggregations, exclude_grouping_columns=True)
+
+    # If we added a dummy grouping column, make sure it's excluded
+    if not groupings and "__dummy_group__" in result.columns:
+        result = result.drop("__dummy_group__")
+
+    # Apply HAVING condition if present
+    if aggregate.HasField("having_condition"):
+        from snowflake.snowpark_connect.expression.hybrid_column_map import (
+            create_hybrid_column_map_for_having,
+        )
+
+        # Create aggregated DataFrame column map
+        aggregated_column_map = DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=spark_columns,
+            snowpark_column_names=snowpark_columns,
+            snowpark_column_types=snowpark_column_types,
+        ).column_map
+
+        # Create hybrid column map that can resolve both input and aggregate contexts
+        hybrid_map = create_hybrid_column_map_for_having(
+            input_df=input_df,
+            input_column_map=input_container.column_map,
+            aggregated_df=result,
+            aggregated_column_map=aggregated_column_map,
+            aggregate_expressions=list(aggregate.aggregate_expressions),
+            grouping_expressions=list(aggregate.grouping_expressions),
+            spark_columns=spark_columns,
+            raw_aggregations=raw_aggregations,
+        )
+
+        # Map the HAVING condition using hybrid resolution
+        _, having_column = hybrid_map.resolve_expression(aggregate.having_condition)
+
+        # Apply the HAVING filter
+        result = result.filter(having_column.col)
 
     if aggregate.group_type == snowflake_proto.Aggregate.GROUP_TYPE_GROUPING_SETS:
         # Immediately drop extra columns. Unlike other GROUP BY operations,
         # grouping sets don't allow ORDER BY with columns that aren't in the aggregate list.
-        result = result.select(result.columns[-len(
+        result = result.select(result.columns[-len(aggregations) :])
 
-    #
-
+    # Return only aggregation columns in the column map
+    return DataFrameContainer.create_with_column_mapping(
         dataframe=result,
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
-
-
-        # Drop the groupings.
-        grouping_count = len(groupings)
-
-        return DataFrameContainer.create_with_column_mapping(
-            result.drop(snowpark_columns[:grouping_count]),
-            spark_columns[grouping_count:],
-            snowpark_columns[grouping_count:],
-            snowpark_column_types[grouping_count:],
-            parent_column_name_map=result_container.column_map,
+        parent_column_name_map=input_df._column_map,
     )
```
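The rewritten GROUP_TYPE_GROUPBY branch handles a global aggregate (no grouping keys) by grouping on a constant helper column and excluding it from the output. A minimal sketch of the same idea in plain Snowpark, assuming an already-configured `Session`; the sample data and the `total` alias are illustrative, not taken from the package:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import lit, sum as sum_


def global_sum(session: Session):
    # A tiny stand-in for the mapped input DataFrame.
    df = session.create_dataframe([[1, 10], [2, 20], [3, 30]], schema=["id", "v"])

    # No explicit GROUP BY: group on a constant so .agg() still applies,
    # then drop the helper column so it never reaches the caller.
    return (
        df.with_column("__dummy_group__", lit(1))
        .group_by("__dummy_group__")
        .agg(sum_("v").alias("total"))
        .drop("__dummy_group__")
        .collect()  # e.g. [Row(TOTAL=60)]
    )
```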
```diff
@@ -4,6 +4,7 @@
 
 import json
 import re
+from json import JSONDecodeError
 
 import numpy as np
 import pyarrow as pa
```
```diff
@@ -19,6 +20,7 @@ from snowflake.snowpark_connect.column_name_handler import (
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.type_mapping import (
+    get_python_sql_utils_class,
     map_json_schema_to_snowpark,
     map_pyarrow_to_snowpark_types,
     map_simple_types,
```
```diff
@@ -34,7 +36,12 @@ def parse_local_relation_schema_string(rel: relation_proto.Relation):
     # schema_str can be a dict, or just a type string, e.g. INTEGER.
     schema_str = rel.local_relation.schema
     assert schema_str
-
+    try:
+        schema_dict = json.loads(schema_str)
+    except JSONDecodeError:
+        # Legacy scala clients sends unparsed struct type strings like "struct<id:bigint,a:int,b:double>"
+        spark_datatype = get_python_sql_utils_class().parseDataType(schema_str)
+        schema_dict = json.loads(spark_datatype.json())
 
     column_metadata = {}
     if isinstance(schema_dict, dict):
```
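The new fallback covers legacy Scala clients that send a raw type string such as `struct<id:bigint,a:int,b:double>` instead of a JSON schema; the real code hands that string to Spark's `parseDataType` via `get_python_sql_utils_class()`. The sketch below shows only the try/except dispatch, with a deliberately simplified, hypothetical parser for flat struct strings standing in for Spark:

```python
import json
from json import JSONDecodeError


def _parse_struct_string(schema_str: str) -> dict:
    # Hypothetical stand-in for Spark's parseDataType: handles only flat
    # "struct<name:type,...>" strings (no nesting, no parametrized types).
    inner = schema_str.strip()[len("struct<"):-1]
    fields = [
        {"name": name, "type": typ, "nullable": True, "metadata": {}}
        for name, typ in (part.split(":", 1) for part in inner.split(","))
    ]
    return {"type": "struct", "fields": fields}


def load_schema(schema_str: str) -> dict:
    try:
        return json.loads(schema_str)            # modern clients send JSON
    except JSONDecodeError:
        return _parse_struct_string(schema_str)  # legacy struct<...> strings


print(load_schema("struct<id:bigint,a:int,b:double>")["fields"][0]["type"])  # bigint
```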
```diff
@@ -14,7 +14,6 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils.pandas_udtf_utils import create_pandas_udtf
-from snowflake.snowpark_connect.utils.session import get_python_udxf_import_files
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
     process_udf_in_sproc,
```
```diff
@@ -28,6 +27,9 @@ from snowflake.snowpark_connect.utils.udtf_helper import (
     create_pandas_udtf_in_sproc,
     require_creating_udtf_in_sproc,
 )
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)
 
 
 def map_map_partitions(
```
```diff
@@ -1,6 +1,7 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
+from copy import copy
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
```
```diff
@@ -8,6 +9,7 @@ from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentExc
 
 import snowflake.snowpark_connect.relation.utils as utils
 from snowflake import snowpark
+from snowflake.snowpark._internal.analyzer.binary_expression import And
 from snowflake.snowpark.functions import col, expr as snowpark_expr
 from snowflake.snowpark.types import (
     BooleanType,
```
```diff
@@ -29,6 +31,7 @@ from snowflake.snowpark_connect.expression.map_expression import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.utils import can_filter_be_flattened
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
```
```diff
@@ -551,7 +554,33 @@ def map_filter(
     _, condition = map_single_column_expression(
         rel.filter.condition, input_container.column_map, typer
     )
-
+
+    select_statement = getattr(input_df, "_select_statement", None)
+    condition_exp = condition.col._expression
+    if (
+        can_filter_be_flattened(select_statement, condition_exp)
+        and input_df._ops_after_agg is None
+    ):
+        new = copy(select_statement)
+        new.from_ = select_statement.from_.to_subqueryable()
+        new.pre_actions = new.from_.pre_actions
+        new.post_actions = new.from_.post_actions
+        new.column_states = select_statement.column_states
+        new.where = (
+            And(select_statement.where, condition_exp)
+            if select_statement.where is not None
+            else condition_exp
+        )
+        new._merge_projection_complexity_with_subquery = False
+        new.df_ast_ids = (
+            select_statement.df_ast_ids.copy()
+            if select_statement.df_ast_ids is not None
+            else None
+        )
+        new.attributes = select_statement.attributes
+        result = input_df._with_plan(new)
+    else:
+        result = input_df.filter(condition.col)
 
     return DataFrameContainer(
         result,
```
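The flattening path above folds the new predicate into the existing SELECT's WHERE clause (joined with AND) instead of wrapping the plan in another subquery, which keeps the generated SQL shallow. A framework-free sketch of the idea; this `SelectStatement` is a toy stand-in, not Snowpark's internal class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SelectStatement:
    # Toy query node: just enough state to show the flattening idea.
    from_: str
    where: Optional[str] = None

    def sql(self) -> str:
        clause = f" WHERE {self.where}" if self.where else ""
        return f"SELECT * FROM {self.from_}{clause}"


def flatten_filter(stmt: SelectStatement, predicate: str) -> SelectStatement:
    # Instead of SELECT * FROM (<stmt>) WHERE <predicate>, merge the predicate
    # into the existing WHERE with AND.
    combined = f"({stmt.where}) AND ({predicate})" if stmt.where else predicate
    return SelectStatement(from_=stmt.from_, where=combined)


stmt = SelectStatement(from_="orders", where="amount > 10")
stmt = flatten_filter(stmt, "region = 'EU'")
print(stmt.sql())  # SELECT * FROM orders WHERE (amount > 10) AND (region = 'EU')
```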
```diff
@@ -14,7 +14,10 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 import sqlglot
 from google.protobuf.any_pb2 import Any
-from pyspark.errors.exceptions.base import
+from pyspark.errors.exceptions.base import (
+    AnalysisException,
+    UnsupportedOperationException,
+)
 from sqlglot.expressions import ColumnDef, DataType, FileFormatProperty, Identifier
 
 import snowflake.snowpark.functions as snowpark_fn
```
```diff
@@ -27,7 +30,6 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 )
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
-from snowflake.snowpark.functions import when_matched, when_not_matched
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
     get_boolean_session_config_param,
```
```diff
@@ -56,16 +58,15 @@ from snowflake.snowpark_connect.utils.context import (
     get_session_id,
     push_evaluating_sql_scope,
     push_sql_scope,
-    set_plan_id_map,
     set_sql_args,
     set_sql_plan_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
+    telemetry,
 )
 
-from .. import column_name_handler
 from ..expression.map_sql_expression import (
     _window_specs,
     as_java_list,
```
```diff
@@ -76,6 +77,9 @@ from ..expression.map_sql_expression import (
 from ..utils.identifiers import spark_to_sf_single_id
 
 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
+_having_condition = ContextVar[expressions_proto.Expression | None](
+    "_having_condition", default=None
+)
 
 
 def _is_sql_select_statement_helper(sql_string: str) -> bool:
```
```diff
@@ -163,6 +167,7 @@ def parse_pos_args(
 
 def execute_logical_plan(logical_plan) -> DataFrameContainer:
     proto = map_logical_plan_relation(logical_plan)
+    telemetry.report_parsed_sql_plan(proto)
     with push_evaluating_sql_scope():
         return map_relation(proto)
 
```
```diff
@@ -712,197 +717,22 @@ def map_sql_to_pandas_df(
                 f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
             ).collect()
         case "MergeIntoTable":
-
-
-
-                if action.condition().isDefined():
-                    (_, condition_typed_col,) = map_single_column_expression(
-                        map_logical_plan_expression(
-                            matched_action.condition().get()
-                        ),
-                        column_mapping,
-                        typer,
-                    )
-                    condition = condition_typed_col.col
-                return condition
-
-            def _get_assignments_from_action(
-                action,
-                column_mapping_source,
-                column_mapping_target,
-                typer_source,
-                typer_target,
-            ):
-                assignments = dict()
-                if (
-                    action.getClass().getSimpleName() == "InsertAction"
-                    or action.getClass().getSimpleName() == "UpdateAction"
-                ):
-                    incoming_assignments = as_java_list(action.assignments())
-                    for assignment in incoming_assignments:
-                        (key_name, _) = map_single_column_expression(
-                            map_logical_plan_expression(assignment.key()),
-                            column_mapping=column_mapping_target,
-                            typer=typer_source,
-                        )
-
-                        (_, val_typ_col) = map_single_column_expression(
-                            map_logical_plan_expression(assignment.value()),
-                            column_mapping=column_mapping_source,
-                            typer=typer_target,
-                        )
-
-                        assignments[key_name] = val_typ_col.col
-                elif (
-                    action.getClass().getSimpleName() == "InsertStarAction"
-                    or action.getClass().getSimpleName() == "UpdateStarAction"
-                ):
-                    if len(column_mapping_source.columns) != len(
-                        column_mapping_target.columns
-                    ):
-                        raise ValueError(
-                            "source and target must have the same number of columns for InsertStarAction or UpdateStarAction"
-                        )
-                    for i, col in enumerate(column_mapping_target.columns):
-                        if assignments.get(col.snowpark_name) is not None:
-                            raise SnowparkConnectNotImplementedError(
-                                "UpdateStarAction or InsertStarAction is not supported with duplicate columns."
-                            )
-                        assignments[col.snowpark_name] = snowpark_fn.col(
-                            column_mapping_source.columns[i].snowpark_name
-                        )
-                return assignments
-
-            source_df_container = map_relation(
-                map_logical_plan_relation(logical_plan.sourceTable())
-            )
-            source_df = source_df_container.dataframe
-            plan_id = gen_sql_plan_id()
-            target_df_container = map_relation(
-                map_logical_plan_relation(logical_plan.targetTable(), plan_id)
-            )
-            target_df = target_df_container.dataframe
-
-            for col in target_df_container.column_map.columns:
-                target_df = target_df.with_column_renamed(
-                    col.snowpark_name,
-                    spark_to_sf_single_id(col.spark_name, is_column=True),
-                )
-            target_df_container = DataFrameContainer.create_with_column_mapping(
-                dataframe=target_df,
-                spark_column_names=target_df.columns,
-                snowpark_column_names=target_df.columns,
-            )
-
-            set_plan_id_map(plan_id, target_df_container)
-
-            joined_df_before_condition: snowpark.DataFrame = source_df.join(
-                target_df
-            )
-
-            column_mapping_for_conditions = column_name_handler.JoinColumnNameMap(
-                source_df_container.column_map,
-                target_df_container.column_map,
-            )
-            typer_for_expressions = ExpressionTyper(joined_df_before_condition)
-
-            (_, merge_condition_typed_col,) = map_single_column_expression(
-                map_logical_plan_expression(logical_plan.mergeCondition()),
-                column_mapping=column_mapping_for_conditions,
-                typer=typer_for_expressions,
-            )
-
-            clauses = []
-
-            for matched_action in as_java_list(logical_plan.matchedActions()):
-                condition = _get_condition_from_action(
-                    matched_action,
-                    column_mapping_for_conditions,
-                    typer_for_expressions,
-                )
-                if matched_action.getClass().getSimpleName() == "DeleteAction":
-                    clauses.append(when_matched(condition).delete())
-                elif (
-                    matched_action.getClass().getSimpleName() == "UpdateAction"
-                    or matched_action.getClass().getSimpleName()
-                    == "UpdateStarAction"
-                ):
-                    assignments = _get_assignments_from_action(
-                        matched_action,
-                        source_df_container.column_map,
-                        target_df_container.column_map,
-                        ExpressionTyper(source_df),
-                        ExpressionTyper(target_df),
-                    )
-                    clauses.append(when_matched(condition).update(assignments))
-
-            for not_matched_action in as_java_list(
-                logical_plan.notMatchedActions()
-            ):
-                condition = _get_condition_from_action(
-                    not_matched_action,
-                    column_mapping_for_conditions,
-                    typer_for_expressions,
-                )
-                if (
-                    not_matched_action.getClass().getSimpleName() == "InsertAction"
-                    or not_matched_action.getClass().getSimpleName()
-                    == "InsertStarAction"
-                ):
-                    assignments = _get_assignments_from_action(
-                        not_matched_action,
-                        source_df_container.column_map,
-                        target_df_container.column_map,
-                        ExpressionTyper(source_df),
-                        ExpressionTyper(target_df),
-                    )
-                    clauses.append(when_not_matched(condition).insert(assignments))
-
-            if not as_java_list(logical_plan.notMatchedBySourceActions()).isEmpty():
-                raise SnowparkConnectNotImplementedError(
-                    "Snowflake does not support 'not matched by source' actions in MERGE statements."
-                )
-
-            if (
-                logical_plan.targetTable().getClass().getSimpleName()
-                == "UnresolvedRelation"
-            ):
-                target_table_name = _spark_to_snowflake(
-                    logical_plan.targetTable().multipartIdentifier()
-                )
-            else:
-                target_table_name = _spark_to_snowflake(
-                    logical_plan.targetTable().child().multipartIdentifier()
-                )
-            session.table(target_table_name).merge(
-                source_df, merge_condition_typed_col.col, clauses
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The MERGE INTO command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
         case "DeleteFromTable":
-
-
-
-            df = df_container.dataframe
-            for col in df_container.column_map.columns:
-                df = df.with_column_renamed(
-                    col.snowpark_name,
-                    spark_to_sf_single_id(col.spark_name, is_column=True),
-                )
-            df_container = column_name_handler.create_with_column_mapping(
-                dataframe=df,
-                spark_column_names=df.columns,
-                snowpark_column_names=df.columns,
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The DELETE FROM command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
-
-
-
-
-
-
-                map_logical_plan_expression(logical_plan.condition()),
-                df_container.column_map,
-                ExpressionTyper(df),
+        case "UpdateTable":
+            # Databricks/Delta-specific extension not supported by SAS.
+            # Provide an actionable, clear error.
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The UPDATE TABLE command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
-            session.table(name).delete(condition_typed_col.col)
         case "RenameColumn":
             table_name = get_relation_identifier_name(logical_plan.table(), True)
             column_obj = logical_plan.column()
```
```diff
@@ -1319,6 +1149,7 @@ def map_logical_plan_relation(
                     grouping_expressions=grouping_expressions,
                     aggregate_expressions=aggregate_expressions,
                     grouping_sets=grouping_sets,
+                    having_condition=_having_condition.get(),
                 )
             )
         )
```
```diff
@@ -1562,12 +1393,25 @@ def map_logical_plan_relation(
             )
         )
         case "UnresolvedHaving":
-
-
-
-
+            # Store the having condition in context and process the child aggregate
+            child_relation = rel.child()
+            if str(child_relation.getClass().getSimpleName()) != "Aggregate":
+                raise SnowparkConnectNotImplementedError(
+                    "UnresolvedHaving can only be applied to Aggregate relations"
                 )
-
+
+            # Store having condition in a context variable for the Aggregate case to pick up
+            having_condition = map_logical_plan_expression(rel.havingCondition())
+
+            # Store in thread-local context (similar to how _ctes works)
+            token = _having_condition.set(having_condition)
+
+            try:
+                # Recursively call map_logical_plan_relation on the child Aggregate
+                # The Aggregate case will pick up the having condition from context
+                proto = map_logical_plan_relation(child_relation, plan_id)
+            finally:
+                _having_condition.reset(token)
         case "UnresolvedHint":
             proto = relation_proto.Relation(
                 hint=relation_proto.Hint(
```
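The HAVING condition travels from the `UnresolvedHaving` node to its child `Aggregate` through a `ContextVar` that is set before the recursive call and always reset in `finally`, so it cannot leak into sibling plans. The pattern in isolation:

```python
from contextvars import ContextVar

_having_condition = ContextVar("_having_condition", default=None)


def process_aggregate() -> str:
    # The child reads whatever the parent stored for the duration of the call.
    return f"aggregate(having={_having_condition.get()!r})"


def process_having(condition: str) -> str:
    token = _having_condition.set(condition)  # stash for the child
    try:
        return process_aggregate()            # recursive mapping sees the value
    finally:
        _having_condition.reset(token)        # never leaks past this node


print(process_having("SUM(v) > 10"))  # aggregate(having='SUM(v) > 10')
print(_having_condition.get())        # back to None after the call
```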
```diff
@@ -31,10 +31,7 @@ from snowflake.snowpark_connect.type_mapping import (
     proto_to_snowpark_type,
 )
 from snowflake.snowpark_connect.utils.context import push_udtf_context
-from snowflake.snowpark_connect.utils.session import (
-    get_or_create_snowpark_session,
-    get_python_udxf_import_files,
-)
+from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udtf_helper import (
     SnowparkUDTF,
     create_udtf_in_sproc,
```
```diff
@@ -42,6 +39,9 @@ from snowflake.snowpark_connect.utils.udtf_helper import (
     udtf_check,
 )
 from snowflake.snowpark_connect.utils.udtf_utils import create_udtf
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)
 
 
 def build_expected_types_from_parsed(
```
```diff
@@ -95,7 +95,8 @@ def map_read(
     if len(rel.read.data_source.paths) > 0:
         # Normalize paths to ensure consistent behavior
         clean_source_paths = [
-            str(Path(path))
+            path.rstrip("/") if is_cloud_path(path) else str(Path(path))
+            for path in rel.read.data_source.paths
         ]
 
         result = _read_file(
```
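`str(Path(path))` collapses the `//` in a URI scheme (e.g. `s3://bucket` becomes `s3:/bucket`), so the change only strips trailing slashes from cloud paths and keeps `Path` normalization for local ones. A small sketch; `is_cloud_path` here is an assumed helper with an illustrative scheme list, and the package's own check may differ:

```python
from pathlib import Path

_CLOUD_PREFIXES = ("s3://", "s3a://", "gs://", "wasbs://", "abfss://")


def is_cloud_path(path: str) -> bool:
    # Assumed scheme list for illustration only.
    return path.lower().startswith(_CLOUD_PREFIXES)


def normalize_paths(paths: list[str]) -> list[str]:
    # Cloud URIs: only trim trailing "/" (Path() would mangle the scheme).
    # Local paths: let Path() canonicalize "./" and duplicate separators.
    return [p.rstrip("/") if is_cloud_path(p) else str(Path(p)) for p in paths]


print(normalize_paths(["s3://bucket/data/", "./local//dir/file.parquet"]))
# ['s3://bucket/data', 'local/dir/file.parquet']
```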
```diff
@@ -6,6 +6,7 @@ import copy
 import json
 import typing
 from contextlib import suppress
+from datetime import datetime
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
```
```diff
@@ -21,6 +22,7 @@ from snowflake.snowpark.types import (
     StringType,
     StructField,
     StructType,
+    TimestampType,
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.map_read import JsonReaderConfig
```
```diff
@@ -204,6 +206,8 @@ def merge_row_schema(
             next_level_content = row[col_name]
             if next_level_content is not None:
                 with suppress(json.JSONDecodeError):
+                    if isinstance(next_level_content, datetime):
+                        next_level_content = str(next_level_content)
                     next_level_content = json.loads(next_level_content)
                     if isinstance(next_level_content, dict):
                         sf.datatype = merge_json_schema(
```
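Handing a `datetime` straight to `json.loads` raises `TypeError`, which the surrounding `suppress(json.JSONDecodeError)` would not swallow; converting it to a string first turns that into an ordinary, suppressed parse failure, which is presumably why the conversion was added. Reduced to the essentials:

```python
import json
from contextlib import suppress
from datetime import datetime

cell = datetime(2024, 1, 1, 12, 30)

with suppress(json.JSONDecodeError):
    if isinstance(cell, datetime):
        # Without this, json.loads(cell) raises TypeError and escapes the suppress block.
        cell = str(cell)
    cell = json.loads(cell)  # raises JSONDecodeError here, which is suppressed

print(cell)  # 2024-01-01 12:30:00 (left as a plain string)
```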
```diff
@@ -235,6 +239,9 @@ def merge_row_schema(
                             sf.datatype.element_type,
                             dropFieldIfAllNull,
                         )
+                    elif isinstance(sf.datatype, TimestampType):
+                        sf.datatype = StringType()
+                    columns_with_valid_contents.add(col_name)
             elif row[col_name] is not None:
                 columns_with_valid_contents.add(col_name)
 
```
```diff
@@ -265,7 +272,7 @@ def construct_dataframe_by_schema(
     rows: typing.Iterator[Row],
     session: snowpark.Session,
     snowpark_options: dict,
-    batch_size: int =
+    batch_size: int = 1000,
 ) -> snowpark.DataFrame:
     result = None
 
```
```diff
@@ -280,6 +287,8 @@ def construct_dataframe_by_schema(
                 session,
             )
 
+            current_data = []
+
     if len(current_data) > 0:
         result = union_data_into_df(
             result,
```
```diff
@@ -288,6 +297,8 @@ def construct_dataframe_by_schema(
             session,
         )
 
+        current_data = []
+
     if result is None:
         raise ValueError("Dataframe cannot be empty")
     return result
```
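Both `construct_dataframe_by_schema` hunks fix the same bug: the row buffer was not cleared after a flush, so already-unioned rows would be appended again on the next batch. The control flow, reduced to plain Python:

```python
from typing import Iterable, Iterator, List


def batched(rows: Iterable[int], batch_size: int = 1000) -> Iterator[List[int]]:
    """Yield rows in batches, clearing the buffer after every flush."""
    current_data: List[int] = []
    for row in rows:
        current_data.append(row)
        if len(current_data) >= batch_size:
            yield current_data
            current_data = []  # the fix: reset the buffer after flushing

    if current_data:           # flush the trailing partial batch
        yield current_data


print(list(batched(range(7), batch_size=3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```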
```diff
@@ -54,10 +54,17 @@ def map_read_parquet(
     if len(paths) == 1:
         df = _read_parquet_with_partitions(session, reader, paths[0])
     else:
+        is_merge_schema = options.config.get("mergeschema")
         df = _read_parquet_with_partitions(session, reader, paths[0])
+        schema_cols = df.columns
         for p in paths[1:]:
             reader._user_schema = None
-            df = df.
+            df = df.union_all_by_name(
+                _read_parquet_with_partitions(session, reader, p),
+                allow_missing_columns=True,
+            )
+        if not is_merge_schema:
+            df = df.select(*schema_cols)
 
     renamed_df, snowpark_column_names = rename_columns_as_snowflake_standard(
         df, rel.common.plan_id
```
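Multi-path Parquet reads are now combined by column name, which tolerates files with different column sets; when `mergeSchema` is not requested the result is projected back onto the first file's columns. A hedged sketch of the same semantics with in-memory DataFrames, assuming an active Snowpark `Session` and a snowflake-snowpark-python version that supports `allow_missing_columns`:

```python
from snowflake.snowpark import Session


def read_like_merge_schema(session: Session, merge_schema: bool):
    # Two "files" with partially overlapping columns.
    first = session.create_dataframe([[1, "a"]], schema=["id", "name"])
    second = session.create_dataframe([[2, 3.5]], schema=["id", "score"])

    schema_cols = first.columns  # remember the first file's schema

    # Align by name; columns missing on either side become NULL.
    df = first.union_all_by_name(second, allow_missing_columns=True)

    if not merge_schema:
        # Without mergeSchema, keep only the first file's columns, as in the change above.
        df = df.select(*schema_cols)
    return df
```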
```diff
@@ -346,6 +346,7 @@ class JsonReaderConfig(ReaderWriterConfig):
             "compression",
             # "ignoreNullFields",
             "rowsToInferSchema",
+            # "inferTimestamp",
         },
         boolean_config_list=[
             "multiLine",
```
```diff
@@ -397,3 +398,12 @@ class ParquetReaderConfig(ReaderWriterConfig):
             ),
             options,
         )
+
+    def convert_to_snowpark_args(self) -> dict[str, Any]:
+        snowpark_args = super().convert_to_snowpark_args()
+
+        # Should be determined by spark.sql.parquet.binaryAsString, but currently Snowpark Connect only supports
+        # the default value (false). TODO: Add support for spark.sql.parquet.binaryAsString equal to "true".
+        snowpark_args["BINARY_AS_TEXT"] = False
+
+        return snowpark_args
```