snowpark-connect 0.27.0__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +3 -93
- snowflake/snowpark_connect/config.py +99 -1
- snowflake/snowpark_connect/dataframe_container.py +0 -6
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_expression.py +22 -7
- snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +4 -26
- snowflake/snowpark_connect/expression/map_unresolved_function.py +12 -3
- snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/map_extension.py +14 -10
- snowflake/snowpark_connect/relation/map_join.py +62 -258
- snowflake/snowpark_connect/relation/map_relation.py +5 -1
- snowflake/snowpark_connect/relation/map_sql.py +464 -68
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/write/map_write.py +228 -120
- snowflake/snowpark_connect/resources_initializer.py +20 -5
- snowflake/snowpark_connect/server.py +16 -17
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +21 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
- snowflake/snowpark_connect/utils/identifiers.py +128 -2
- snowflake/snowpark_connect/utils/io_utils.py +21 -1
- snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
- snowflake/snowpark_connect/utils/session.py +16 -26
- snowflake/snowpark_connect/utils/telemetry.py +53 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/RECORD +41 -41
- snowflake/snowpark_connect/hidden_column.py +0 -39
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_sql.py
@@ -4,7 +4,7 @@
 
 import re
 from collections.abc import MutableMapping, MutableSequence
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from contextvars import ContextVar
 from functools import reduce
 
@@ -30,10 +30,13 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 )
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
+from snowflake.snowpark.functions import when_matched, when_not_matched
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
+    check_table_supports_operation,
     get_boolean_session_config_param,
     global_config,
+    record_table_metadata,
     set_config_param,
     unset_config_param,
 )
@@ -58,7 +61,9 @@ from snowflake.snowpark_connect.utils.context import (
     get_session_id,
     get_sql_plan,
     push_evaluating_sql_scope,
+    push_processed_view,
     push_sql_scope,
+    set_plan_id_map,
     set_sql_args,
     set_sql_plan_name,
 )
@@ -68,6 +73,7 @@ from snowflake.snowpark_connect.utils.telemetry import (
     telemetry,
 )
 
+from .. import column_name_handler
 from ..expression.map_sql_expression import (
     _window_specs,
     as_java_list,
@@ -75,7 +81,16 @@ from ..expression.map_sql_expression import (
     map_logical_plan_expression,
     sql_parser,
 )
-from ..utils.identifiers import
+from ..utils.identifiers import (
+    spark_to_sf_single_id,
+    spark_to_sf_single_id_with_unquoting,
+)
+from ..utils.temporary_view_cache import (
+    get_temp_view,
+    register_temp_view,
+    unregister_temp_view,
+)
+from .catalogs import SNOWFLAKE_CATALOG
 
 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
 _cte_definitions = ContextVar[dict[str, any]]("_cte_definitions", default={})
@@ -203,6 +218,9 @@ def _rename_columns(
 def _create_table_as_select(logical_plan, mode: str) -> None:
     # TODO: for as select create tables we'd map multi layer identifier here
     name = get_relation_identifier_name(logical_plan.name())
+    full_table_identifier = get_relation_identifier_name(
+        logical_plan.name(), is_multi_part=True
+    )
     comment = logical_plan.tableSpec().comment()
 
     container = execute_logical_plan(logical_plan.query())
@@ -223,6 +241,15 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
         mode=mode,
     )
 
+    # Record table metadata for CREATE TABLE AS SELECT
+    # These are typically considered v2 tables and support RENAME COLUMN
+    record_table_metadata(
+        table_identifier=full_table_identifier,
+        table_type="v2",
+        data_source="default",
+        supports_column_rename=True,
+    )
+
 
 def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
     # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase",
@@ -300,6 +327,65 @@ def _remove_column_data_type(node):
     return node
 
 
+def _get_condition_from_action(action, column_mapping, typer):
+    condition = None
+    if action.condition().isDefined():
+        (_, condition_typed_col,) = map_single_column_expression(
+            map_logical_plan_expression(action.condition().get()),
+            column_mapping,
+            typer,
+        )
+        condition = condition_typed_col.col
+    return condition
+
+
+def _get_assignments_from_action(
+    action,
+    column_mapping_source,
+    column_mapping_target,
+    typer_source,
+    typer_target,
+):
+    assignments = dict()
+    if (
+        action.getClass().getSimpleName() == "InsertAction"
+        or action.getClass().getSimpleName() == "UpdateAction"
+    ):
+        incoming_assignments = as_java_list(action.assignments())
+        for assignment in incoming_assignments:
+            (_, key_typ_col) = map_single_column_expression(
+                map_logical_plan_expression(assignment.key()),
+                column_mapping=column_mapping_target,
+                typer=typer_target,
+            )
+            key_name = typer_target.df.select(key_typ_col.col).columns[0]
+
+            (_, val_typ_col) = map_single_column_expression(
+                map_logical_plan_expression(assignment.value()),
+                column_mapping=column_mapping_source,
+                typer=typer_source,
+            )
+
+            assignments[key_name] = val_typ_col.col
+    elif (
+        action.getClass().getSimpleName() == "InsertStarAction"
+        or action.getClass().getSimpleName() == "UpdateStarAction"
+    ):
+        if len(column_mapping_source.columns) != len(column_mapping_target.columns):
+            raise ValueError(
+                "source and target must have the same number of columns for InsertStarAction or UpdateStarAction"
+            )
+        for i, col in enumerate(column_mapping_target.columns):
+            if assignments.get(col.snowpark_name) is not None:
+                raise SnowparkConnectNotImplementedError(
+                    "UpdateStarAction or InsertStarAction is not supported with duplicate columns."
+                )
+            assignments[col.snowpark_name] = snowpark_fn.col(
+                column_mapping_source.columns[i].snowpark_name
+            )
+    return assignments
+
+
 def map_sql_to_pandas_df(
     sql_string: str,
     named_args: MutableMapping[str, expressions_proto.Expression.Literal],
@@ -327,6 +413,7 @@ def map_sql_to_pandas_df(
     ) == "UnresolvedHint":
         logical_plan = logical_plan.child()
 
+    # TODO: Add support for temporary views for SQL cases such as ShowViews, ShowColumns ect. (Currently the cases are not compatible with Spark, returning raw Snowflake rows)
     match class_name:
         case "AddColumns":
             # Handle ALTER TABLE ... ADD COLUMNS (col_name data_type) -> ADD COLUMN col_name data_type
@@ -421,6 +508,9 @@ def map_sql_to_pandas_df(
             )
 
             name = get_relation_identifier_name(logical_plan.name())
+            full_table_identifier = get_relation_identifier_name(
+                logical_plan.name(), is_multi_part=True
+            )
             columns = ", ".join(
                 _spark_field_to_sql(f, True)
                 for f in logical_plan.tableSchema().fields()
@@ -431,10 +521,48 @@ def map_sql_to_pandas_df(
                 if comment_opt.isDefined()
                 else ""
             )
+
+            # Extract data source for metadata tracking
+            data_source = "default"
+
+            with suppress(Exception):
+                # Get data source from tableSpec.provider() (for USING clause)
+                if hasattr(logical_plan, "tableSpec"):
+                    table_spec = logical_plan.tableSpec()
+                    if hasattr(table_spec, "provider"):
+                        provider_opt = table_spec.provider()
+                        if provider_opt.isDefined():
+                            data_source = str(provider_opt.get()).lower()
+                        else:
+                            # Fall back to checking properties for FORMAT
+                            table_properties = table_spec.properties()
+                            if not table_properties.isEmpty():
+                                for prop in table_properties.get():
+                                    if str(prop.key()) == "FORMAT":
+                                        data_source = str(prop.value()).lower()
+                                        break
+
             # NOTE: We are intentionally ignoring any FORMAT=... parameters here.
             session.sql(
                 f"CREATE {replace_table} TABLE {if_not_exists}{name} ({columns}) {comment}"
             ).collect()
+
+            # Record table metadata for Spark compatibility
+            # Tables created with explicit schema are considered v1 tables
+            # v1 tables with certain data sources don't support RENAME COLUMN in OSS Spark
+            supports_rename = data_source not in (
+                "parquet",
+                "csv",
+                "json",
+                "orc",
+                "avro",
+            )
+            record_table_metadata(
+                table_identifier=full_table_identifier,
+                table_type="v1",
+                data_source=data_source,
+                supports_column_rename=supports_rename,
+            )
         case "CreateTableAsSelect":
             mode = "ignore" if logical_plan.ignoreIfExists() else "errorifexists"
             _create_table_as_select(logical_plan, mode=mode)
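The metadata recorded above drives the later RenameColumn check. A minimal standalone sketch of the rule it encodes (the actual record_table_metadata/check_table_supports_operation helpers live in snowflake/snowpark_connect/config.py and are not shown in this hunk):

def supports_column_rename(table_type: str, data_source: str) -> bool:
    # CREATE TABLE AS SELECT tables are recorded as "v2" and always allow renames;
    # plain CREATE TABLE is recorded as "v1", and file-format data sources are excluded.
    if table_type == "v2":
        return True
    return data_source.lower() not in ("parquet", "csv", "json", "orc", "avro")

print(supports_column_rename("v1", "parquet"))  # False
print(supports_column_rename("v1", "delta"))    # True
print(supports_column_rename("v2", "default"))  # True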
@@ -460,6 +588,23 @@ def map_sql_to_pandas_df(
             )
             snowflake_sql = parsed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
+            spark_view_name = next(
+                sqlglot.parse_one(sql_string, dialect="spark").find_all(
+                    sqlglot.exp.Table
+                )
+            ).name
+            snowflake_view_name = spark_to_sf_single_id_with_unquoting(
+                spark_view_name
+            )
+            temp_view = get_temp_view(snowflake_view_name)
+            if temp_view is not None and not logical_plan.replace():
+                raise AnalysisException(
+                    f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{spark_view_name}` because it already exists."
+                )
+            else:
+                unregister_temp_view(
+                    spark_to_sf_single_id_with_unquoting(spark_view_name)
+                )
         case "CreateView":
             current_schema = session.connection.schema
             if (
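The temporary-view branch above recovers the view name by re-parsing the original Spark SQL with sqlglot. A small standalone check of that lookup (assumes sqlglot is installed; the first Table node found is the view identifier, since it sits directly under the CREATE node):

import sqlglot

sql_string = "CREATE OR REPLACE TEMPORARY VIEW my_view AS SELECT id FROM src_table"
spark_view_name = next(
    sqlglot.parse_one(sql_string, dialect="spark").find_all(sqlglot.exp.Table)
).name
print(spark_view_name)  # my_view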
@@ -496,54 +641,63 @@ def map_sql_to_pandas_df(
                 else None,
             )
         case "CreateViewCommand":
-
-
-
-
-
-
-
-
+            with push_processed_view(logical_plan.name().identifier()):
+                df_container = execute_logical_plan(logical_plan.plan())
+                df = df_container.dataframe
+                user_specified_spark_column_names = [
+                    str(col._1())
+                    for col in as_java_list(logical_plan.userSpecifiedColumns())
+                ]
+                df_container = DataFrameContainer.create_with_column_mapping(
+                    dataframe=df,
+                    spark_column_names=user_specified_spark_column_names
+                    if user_specified_spark_column_names
+                    else df_container.column_map.get_spark_columns(),
+                    snowpark_column_names=df_container.column_map.get_snowpark_columns(),
+                    parent_column_name_map=df_container.column_map,
                 )
-
-
-
-
-
-                logical_plan.viewType(),
-                jpype.JClass(
-                    "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
-                ),
-            ):
-                name = f"{global_config.spark_sql_globalTempDatabase}.{name}"
-            comment = logical_plan.comment()
-            maybe_comment = (
-                _escape_sql_comment(str(comment.get()))
-                if comment.isDefined()
-                else None
-            )
-
-            df = _rename_columns(
-                df, logical_plan.userSpecifiedColumns(), df_container.column_map
-            )
-
-            if logical_plan.replace():
-                df.create_or_replace_temp_view(
-                    name,
-                    comment=maybe_comment,
+                is_global = isinstance(
+                    logical_plan.viewType(),
+                    jpype.JClass(
+                        "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
+                    ),
                 )
-
-
-
-
+                if is_global:
+                    view_name = [
+                        global_config.spark_sql_globalTempDatabase,
+                        logical_plan.name().quotedString(),
+                    ]
+                else:
+                    view_name = [logical_plan.name().quotedString()]
+                view_name = [
+                    spark_to_sf_single_id_with_unquoting(part) for part in view_name
+                ]
+                joined_view_name = ".".join(view_name)
+
+                register_temp_view(
+                    joined_view_name,
+                    df_container,
+                    logical_plan.replace(),
+                )
+                tmp_views = _get_current_temp_objects()
+                tmp_views.add(
+                    (
+                        CURRENT_CATALOG_NAME,
+                        session.connection.schema,
+                        str(logical_plan.name().identifier()),
+                    )
                 )
         case "DescribeColumn":
-            name =
+            name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.column()
+            )
+            if get_temp_view(name):
+                return SNOWFLAKE_CATALOG.listColumns(unquote_if_quoted(name)), ""
             # todo double check if this is correct
+            name = get_relation_identifier_name(logical_plan.column())
             rows = session.sql(f"DESCRIBE TABLE {name}").collect()
         case "DescribeNamespace":
             name = get_relation_identifier_name(logical_plan.namespace(), True)
-            name = change_default_to_public(name)
             rows = session.sql(f"DESCRIBE SCHEMA {name}").collect()
             if not rows:
                 rows = None
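register_temp_view/get_temp_view/unregister_temp_view come from the new utils/temporary_view_cache.py, which is not shown in this diff. A hypothetical dict-backed sketch of the contract implied by the call sites above (names and behavior are assumptions, not the shipped implementation):

_TEMP_VIEWS: dict[str, object] = {}

def register_temp_view(name: str, container: object, replace: bool) -> None:
    # CreateViewCommand passes logical_plan.replace() as the replace flag
    if not replace and name in _TEMP_VIEWS:
        raise ValueError(f"temporary view {name} already exists")
    _TEMP_VIEWS[name] = container

def get_temp_view(name: str):
    # Returns the cached DataFrame container, or None if the view is unknown
    return _TEMP_VIEWS.get(name)

def unregister_temp_view(name: str) -> bool:
    # DropView relies on a truthy result to skip the DROP VIEW SQL fallback
    return _TEMP_VIEWS.pop(name, None) is not None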
@@ -615,9 +769,13 @@ def map_sql_to_pandas_df(
             if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
             session.sql(f"DROP TABLE {if_exists}{name}").collect()
         case "DropView":
-
-
-
+            temporary_view_name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.child()
+            )
+            if not unregister_temp_view(temporary_view_name):
+                name = get_relation_identifier_name(logical_plan.child())
+                if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
+                session.sql(f"DROP VIEW {if_exists}{name}").collect()
         case "ExplainCommand":
             inner_plan = logical_plan.logicalPlan()
             logical_plan_name = inner_plan.nodeName()
@@ -729,15 +887,147 @@ def map_sql_to_pandas_df(
                 f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
             ).collect()
         case "MergeIntoTable":
-
-
-
+            source_df_container = map_relation(
+                map_logical_plan_relation(logical_plan.sourceTable())
+            )
+            source_df = source_df_container.dataframe
+            plan_id = gen_sql_plan_id()
+            target_df_container = map_relation(
+                map_logical_plan_relation(logical_plan.targetTable(), plan_id)
+            )
+            target_df = target_df_container.dataframe
+
+            if (
+                logical_plan.targetTable().getClass().getSimpleName()
+                == "UnresolvedRelation"
+            ):
+                target_table_name = _spark_to_snowflake(
+                    logical_plan.targetTable().multipartIdentifier()
+                )
+            else:
+                target_table_name = _spark_to_snowflake(
+                    logical_plan.targetTable().child().multipartIdentifier()
+                )
+
+            target_table = session.table(target_table_name)
+            target_table_columns = target_table.columns
+            target_df_spark_names = []
+            for target_table_col, target_df_col in zip(
+                target_table_columns, target_df_container.column_map.columns
+            ):
+                target_df = target_df.with_column_renamed(
+                    target_df_col.snowpark_name,
+                    target_table_col,
+                )
+                target_df_spark_names.append(target_df_col.spark_name)
+            target_df_container = DataFrameContainer.create_with_column_mapping(
+                dataframe=target_df,
+                spark_column_names=target_df_spark_names,
+                snowpark_column_names=target_table_columns,
             )
+
+            set_plan_id_map(plan_id, target_df_container)
+
+            joined_df_before_condition: snowpark.DataFrame = source_df.join(
+                target_df
+            )
+
+            column_mapping_for_conditions = column_name_handler.JoinColumnNameMap(
+                source_df_container.column_map,
+                target_df_container.column_map,
+            )
+            typer_for_expressions = ExpressionTyper(joined_df_before_condition)
+
+            (_, merge_condition_typed_col,) = map_single_column_expression(
+                map_logical_plan_expression(logical_plan.mergeCondition()),
+                column_mapping=column_mapping_for_conditions,
+                typer=typer_for_expressions,
+            )
+
+            clauses = []
+
+            for matched_action in as_java_list(logical_plan.matchedActions()):
+                condition = _get_condition_from_action(
+                    matched_action,
+                    column_mapping_for_conditions,
+                    typer_for_expressions,
+                )
+                if matched_action.getClass().getSimpleName() == "DeleteAction":
+                    clauses.append(when_matched(condition).delete())
+                elif (
+                    matched_action.getClass().getSimpleName() == "UpdateAction"
+                    or matched_action.getClass().getSimpleName()
+                    == "UpdateStarAction"
+                ):
+                    assignments = _get_assignments_from_action(
+                        matched_action,
+                        source_df_container.column_map,
+                        target_df_container.column_map,
+                        ExpressionTyper(source_df),
+                        ExpressionTyper(target_df),
+                    )
+                    clauses.append(when_matched(condition).update(assignments))
+
+            for not_matched_action in as_java_list(
+                logical_plan.notMatchedActions()
+            ):
+                condition = _get_condition_from_action(
+                    not_matched_action,
+                    column_mapping_for_conditions,
+                    typer_for_expressions,
+                )
+                if (
+                    not_matched_action.getClass().getSimpleName() == "InsertAction"
+                    or not_matched_action.getClass().getSimpleName()
+                    == "InsertStarAction"
+                ):
+                    assignments = _get_assignments_from_action(
+                        not_matched_action,
+                        source_df_container.column_map,
+                        target_df_container.column_map,
+                        ExpressionTyper(source_df),
+                        ExpressionTyper(target_df),
+                    )
+                    clauses.append(when_not_matched(condition).insert(assignments))
+
+            if not as_java_list(logical_plan.notMatchedBySourceActions()).isEmpty():
+                raise SnowparkConnectNotImplementedError(
+                    "Snowflake does not support 'not matched by source' actions in MERGE statements."
+                )
+
+            target_table.merge(source_df, merge_condition_typed_col.col, clauses)
         case "DeleteFromTable":
-
-
-                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
+            df_container = map_relation(
+                map_logical_plan_relation(logical_plan.table())
             )
+            name = get_relation_identifier_name(logical_plan.table(), True)
+            table = session.table(name)
+            table_columns = table.columns
+            df = df_container.dataframe
+            spark_names = []
+            for table_col, df_col in zip(
+                table_columns, df_container.column_map.columns
+            ):
+                df = df.with_column_renamed(
+                    df_col.snowpark_name,
+                    table_col,
+                )
+                spark_names.append(df_col.spark_name)
+            df_container = DataFrameContainer.create_with_column_mapping(
+                dataframe=df,
+                spark_column_names=spark_names,
+                snowpark_column_names=table_columns,
+            )
+            df = df_container.dataframe
+            (
+                condition_column_name,
+                condition_typed_col,
+            ) = map_single_column_expression(
+                map_logical_plan_expression(logical_plan.condition()),
+                df_container.column_map,
+                ExpressionTyper(df),
+            )
+            table.delete(condition_typed_col.col)
         case "UpdateTable":
             # Databricks/Delta-specific extension not supported by SAS.
             # Provide an actionable, clear error.
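The MergeIntoTable branch ultimately builds when_matched()/when_not_matched() clauses and hands them to Snowpark's Table.merge. A minimal sketch of that target API, assuming an active Snowpark session and existing TARGET/SOURCE tables with ID and VAL columns (illustrative only, not part of this package):

from snowflake.snowpark import Session
from snowflake.snowpark.functions import when_matched, when_not_matched

session = Session.builder.getOrCreate()  # assumes connection parameters are configured
source = session.table("SOURCE")
target = session.table("TARGET")

target.merge(
    source,
    target["ID"] == source["ID"],  # merge condition
    [
        when_matched().update({"VAL": source["VAL"]}),  # WHEN MATCHED THEN UPDATE
        when_not_matched().insert({"ID": source["ID"], "VAL": source["VAL"]}),  # WHEN NOT MATCHED THEN INSERT
    ],
)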
@@ -746,7 +1036,20 @@ def map_sql_to_pandas_df(
                 + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
         case "RenameColumn":
-
+            full_table_identifier = get_relation_identifier_name(
+                logical_plan.table(), True
+            )
+
+            # Check Spark compatibility for RENAME COLUMN operation
+            if not check_table_supports_operation(
+                full_table_identifier, "rename_column"
+            ):
+                raise AnalysisException(
+                    f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
+                    f"This table was created as a v1 table with a data source that doesn't support column renaming. "
+                    f"To enable this operation, set 'enable_snowflake_extension_behavior' to 'true'."
+                )
+
             column_obj = logical_plan.column()
             old_column_name = ".".join(
                 spark_to_sf_single_id(str(part), is_column=True)
@@ -756,7 +1059,7 @@ def map_sql_to_pandas_df(
             case_insensitive_name = next(
                 (
                     f.name
-                    for f in session.table(
+                    for f in session.table(full_table_identifier).schema.fields
                     if f.name.lower() == old_column_name.lower()
                 ),
                 None,
@@ -768,7 +1071,7 @@ def map_sql_to_pandas_df(
             )
 
             # Pass through to Snowflake
-            snowflake_sql = f"ALTER TABLE {
+            snowflake_sql = f"ALTER TABLE {full_table_identifier} RENAME COLUMN {old_column_name} TO {new_column_name}"
             session.sql(snowflake_sql).collect()
         case "RenameTable":
             name = get_relation_identifier_name(logical_plan.child(), True)
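Taken together with the metadata recorded at CREATE TABLE time, the effect on the Spark Connect surface is roughly the following (a hedged sketch; `spark` is assumed to be a Spark Connect session served by snowpark-connect, and the error text comes from the AnalysisException raised above):

spark.sql("CREATE TABLE t1 (id INT, v STRING) USING parquet")
spark.sql("ALTER TABLE t1 RENAME COLUMN v TO value")
# raises AnalysisException: ALTER TABLE RENAME COLUMN is not supported for table ...

spark.sql("CREATE TABLE t2 AS SELECT 1 AS id, 'a' AS v")
spark.sql("ALTER TABLE t2 RENAME COLUMN v TO value")  # v2 table: passes through to Snowflake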
@@ -795,7 +1098,6 @@ def map_sql_to_pandas_df(
         case "SetCatalogAndNamespace":
             # TODO: add catalog setting here
             name = get_relation_identifier_name(logical_plan.child(), True)
-            name = change_default_to_public(name)
             session.sql(f"USE SCHEMA {name}").collect()
         case "SetCommand":
             kv_result_tuple = logical_plan.kv().get()
@@ -804,7 +1106,6 @@ def map_sql_to_pandas_df(
             set_config_param(get_session_id(), key, val, session)
         case "SetNamespaceCommand":
             name = _spark_to_snowflake(logical_plan.namespace())
-            name = change_default_to_public(name)
             session.sql(f"USE SCHEMA {name}").collect()
         case "SetNamespaceLocation" | "SetNamespaceProperties":
             raise SnowparkConnectNotImplementedError(
@@ -1015,6 +1316,76 @@ def change_default_to_public(name: str) -> str:
     return name
 
 
+def _preprocess_identifier_calls(sql_query: str) -> str:
+    """
+    Pre-process SQL query to resolve IDENTIFIER() calls before Spark parsing.
+
+    Transforms: IDENTIFIER('abs')(c2) -> abs(c2)
+    Transforms: IDENTIFIER('COAL' || 'ESCE')(NULL, 1) -> COALESCE(NULL, 1)
+
+    This preserves all function arguments in their original positions, eliminating
+    the need to reconstruct them at the expression level.
+    """
+    import re
+
+    # Pattern to match IDENTIFIER(...) followed by optional function call arguments
+    # This captures both the identifier expression and any trailing arguments
+    # Note: We need to be careful about whitespace preservation
+    identifier_pattern = r"IDENTIFIER\s*\(\s*([^)]+)\s*\)(\s*)(\([^)]*\))?"
+
+    def resolve_identifier_match(match):
+        identifier_expr_str = match.group(1).strip()
+        whitespace = match.group(2) if match.group(2) else ""
+        function_args = match.group(3) if match.group(3) else ""
+
+        try:
+            # Handle string concatenation FIRST: IDENTIFIER('COAL' || 'ESCE')
+            # (Must check this before simple strings since it also starts/ends with quotes)
+            if "||" in identifier_expr_str:
+                # Parse basic string concatenation with proper quote handling
+                parts = []
+                split_parts = identifier_expr_str.split("||")
+                for part in split_parts:
+                    part = part.strip()
+                    if part.startswith("'") and part.endswith("'"):
+                        unquoted = part[1:-1]  # Remove quotes from each part
+                        parts.append(unquoted)
+                    else:
+                        # Non-string parts - return original for safety
+                        return match.group(0)
+                resolved_name = "".join(parts)  # Concatenate the unquoted parts
+
+            # Handle simple string literals: IDENTIFIER('abs')
+            elif identifier_expr_str.startswith("'") and identifier_expr_str.endswith(
+                "'"
+            ):
+                resolved_name = identifier_expr_str[1:-1]  # Remove quotes
+
+            else:
+                # Complex expressions not supported yet - return original
+                return match.group(0)
+
+            # Return resolved function call with preserved arguments and whitespace
+            if function_args:
+                # Function call case: IDENTIFIER('abs')(c1) -> abs(c1)
+                result = f"{resolved_name}{function_args}"
+            else:
+                # Column reference case: IDENTIFIER('c1') FROM -> c1 FROM (preserve whitespace)
+                result = f"{resolved_name}{whitespace}"
+            return result
+
+        except Exception:
+            # Return original to avoid breaking the query
+            return match.group(0)
+
+    # Apply the transformation
+    processed_query = re.sub(
+        identifier_pattern, resolve_identifier_match, sql_query, flags=re.IGNORECASE
+    )
+
+    return processed_query
+
+
 def map_sql(
     rel: relation_proto.Relation,
 ) -> DataFrameContainer:
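A standalone demonstration of the substitution _preprocess_identifier_calls performs, using the same pattern (only the simple quoted-literal case is exercised here):

import re

identifier_pattern = r"IDENTIFIER\s*\(\s*([^)]+)\s*\)(\s*)(\([^)]*\))?"

def resolve(match):
    expr = match.group(1).strip()
    trailing_args = match.group(3) or ""
    if expr.startswith("'") and expr.endswith("'") and "||" not in expr:
        # IDENTIFIER('abs')(c2) -> abs(c2); IDENTIFIER('c1') -> c1
        return f"{expr[1:-1]}{trailing_args or match.group(2) or ''}"
    return match.group(0)  # leave anything else untouched

print(re.sub(identifier_pattern, resolve, "SELECT IDENTIFIER('abs')(c2) FROM t", flags=re.IGNORECASE))
# SELECT abs(c2) FROM t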
@@ -1844,21 +2215,46 @@ def map_logical_plan_relation(
     return proto
 
 
-def
-
-
-
-
-
-
-
-
-
-
+def _get_relation_identifier(name_obj) -> str:
+    # IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
+    expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
+    session = snowpark.Session.get_active_session()
+    m = ColumnNameMap([], [], None)
+    expr = map_single_column_expression(
+        expr_proto, m, ExpressionTyper.dummy_typer(session)
+    )
+    return spark_to_sf_single_id(session.range(1).select(expr[1].col).collect()[0][0])
+
+
+def get_relation_identifier_name_without_uppercasing(name_obj) -> str:
+    if name_obj.getClass().getSimpleName() in (
+        "PlanWithUnresolvedIdentifier",
+        "ExpressionWithUnresolvedIdentifier",
+    ):
+        return _get_relation_identifier(name_obj)
+    else:
+        name = ".".join(
+            quote_name_without_upper_casing(str(part))
+            for part in as_java_list(name_obj.nameParts())
         )
+
+        return name
+
+
+def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
+    if name_obj.getClass().getSimpleName() in (
+        "PlanWithUnresolvedIdentifier",
+        "ExpressionWithUnresolvedIdentifier",
+    ):
+        return _get_relation_identifier(name_obj)
     else:
         if is_multi_part:
-
+            try:
+                # Try multipartIdentifier first for full catalog.database.table
+                name = _spark_to_snowflake(name_obj.multipartIdentifier())
+            except AttributeError:
+                # Fallback to nameParts if multipartIdentifier not available
+                name = _spark_to_snowflake(name_obj.nameParts())
         else:
             name = _spark_to_snowflake(name_obj.nameParts())
 