snowpark-connect 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/column_name_handler.py +3 -93
- snowflake/snowpark_connect/config.py +99 -4
- snowflake/snowpark_connect/dataframe_container.py +0 -6
- snowflake/snowpark_connect/expression/map_expression.py +31 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +22 -26
- snowflake/snowpark_connect/expression/map_unresolved_function.py +28 -10
- snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/map_extension.py +7 -1
- snowflake/snowpark_connect/relation/map_join.py +62 -258
- snowflake/snowpark_connect/relation/map_map_partitions.py +36 -77
- snowflake/snowpark_connect/relation/map_relation.py +8 -2
- snowflake/snowpark_connect/relation/map_show_string.py +2 -0
- snowflake/snowpark_connect/relation/map_sql.py +413 -15
- snowflake/snowpark_connect/relation/write/map_write.py +195 -114
- snowflake/snowpark_connect/resources_initializer.py +20 -5
- snowflake/snowpark_connect/server.py +20 -18
- snowflake/snowpark_connect/utils/artifacts.py +4 -5
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +41 -1
- snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
- snowflake/snowpark_connect/utils/identifiers.py +120 -0
- snowflake/snowpark_connect/utils/io_utils.py +21 -1
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
- snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
- snowflake/snowpark_connect/utils/session.py +16 -26
- snowflake/snowpark_connect/utils/telemetry.py +53 -0
- snowflake/snowpark_connect/utils/udf_utils.py +66 -103
- snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
- snowflake/snowpark_connect/version.py +2 -3
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/RECORD +41 -42
- snowflake/snowpark_connect/hidden_column.py +0 -39
- {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/column_name_handler.py
@@ -20,7 +20,6 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 from snowflake.snowpark._internal.utils import quote_name
 from snowflake.snowpark.types import StructType
 from snowflake.snowpark_connect.config import global_config
-from snowflake.snowpark_connect.hidden_column import HiddenColumn
 from snowflake.snowpark_connect.utils.context import get_current_operation_scope
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
@@ -124,7 +123,6 @@ class ColumnNameMap:
         ] = lambda: global_config.spark_sql_caseSensitive,
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> None:
         """
@@ -135,7 +133,6 @@ class ColumnNameMap:
                 The key is the original Spark column name, and the value is the metadata.
                 example: Dict('age', {'foo': 'bar'})
             column_qualifiers: Optional qualifiers for the columns, used to handle table aliases or DataFrame aliases.
-            hidden_columns: Optional set of HiddenColumn objects.
             parent_column_name_map: parent ColumnNameMap
         """
         self.columns: list[ColumnNames] = []
@@ -144,7 +141,6 @@ class ColumnNameMap:
         self.snowpark_to_col = defaultdict(list)
         self.is_case_sensitive = is_case_sensitive
         self.column_metadata = column_metadata
-        self.hidden_columns = hidden_columns

         # Rename chain dictionary to track column renaming history
         self.rename_chains: dict[str, str] = {}  # old_name -> new_name mapping
@@ -338,8 +334,6 @@ class ColumnNameMap:
         *,
         allow_non_exists: bool = False,
         return_first: bool = False,
-        is_qualified: bool = False,
-        source_qualifiers: list[str] | None = None,
     ) -> str | None:
         assert isinstance(spark_column_name, str)
         resolved_name = (
@@ -347,37 +341,9 @@
             if self.rename_chains
             else spark_column_name
         )
-
-
-
-        # This is because that will return the name of the using column that's been dropped from the result
-        # dataframe. We want to fetch and resolve the hidden column to its visible using column name instead.
-        # Even if this is an unqualified reference or one to the visible column, it will resolve correctly to
-        # the visible name anyway.
-        snowpark_names = []
-        # Only check hidden columns for qualified references with source qualifiers
-        if is_qualified and source_qualifiers is not None and self.hidden_columns:
-            column_name = spark_column_name
-
-            # Check each hidden column for column name AND qualifier match
-            for hidden_col in self.hidden_columns:
-                if (
-                    hidden_col.spark_name == column_name
-                    and hidden_col.qualifiers == source_qualifiers
-                ):
-                    if not global_config.spark_sql_caseSensitive:
-                        if hidden_col.spark_name.upper() == column_name.upper() and [
-                            q.upper() for q in hidden_col.qualifiers
-                        ] == [q.upper() for q in source_qualifiers]:
-                            snowpark_names.append(hidden_col.visible_snowpark_name)
-                    else:
-                        snowpark_names.append(hidden_col.visible_snowpark_name)
-
-        # If not found in hidden columns, proceed with normal lookup
-        if not snowpark_names:
-            snowpark_names = self.get_snowpark_column_names_from_spark_column_names(
-                [resolved_name], return_first
-            )
+        snowpark_names = self.get_snowpark_column_names_from_spark_column_names(
+            [resolved_name], return_first
+        )

         snowpark_names_len = len(snowpark_names)
         if snowpark_names_len > 1:
@@ -464,27 +430,6 @@
             snowpark_columns.append(c.snowpark_name)
             qualifiers.append(c.qualifiers)

-        # Note: The following code is commented out because there is a bug with handling duplicate columns in
-        # qualified select *'s. This needs to be revisited once a solution for that is found.
-        # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-2265240
-
-        # # Handles fetching/resolving the hidden columns if they also match the qualifiers
-        # # This method is only ever called for qualified references, so we need to check hidden columns as well.
-        # if self.hidden_columns:
-        # for hidden_col in self.hidden_columns:
-        # col_qualifiers = (
-        # [q.upper() for q in hidden_col.qualifiers]
-        # if not self.is_case_sensitive()
-        # else hidden_col.qualifiers
-        # )
-        # if len(col_qualifiers) < len(qualifiers_input):
-        # continue
-        # if col_qualifiers[-len(qualifiers_input) :] == qualifiers_input:
-        # # This hidden column matches! Add it to the results
-        # spark_columns.append(hidden_col.spark_name)
-        # snowpark_columns.append(hidden_col.visible_snowpark_name)
-        # qualifiers.append(hidden_col.qualifiers)
-
         return spark_columns, snowpark_columns, qualifiers

     def get_snowpark_columns(self) -> list[str]:
@@ -616,35 +561,6 @@
         else:
             return spark_name.upper()

-    def is_hidden_column_reference(
-        self, spark_column_name: str, source_qualifiers: list[str] | None = None
-    ) -> bool:
-        """
-        Check if a column reference would be resolved through hidden columns.
-        """
-        if not self.hidden_columns or source_qualifiers is None:
-            return False
-
-        # For qualified references with source_qualifiers
-        column_name = (
-            spark_column_name  # When has_plan_id=True, this is just the column name
-        )
-
-        for hidden_col in self.hidden_columns:
-            if (
-                hidden_col.spark_name == column_name
-                and hidden_col.qualifiers == source_qualifiers
-            ):
-                if not global_config.spark_sql_caseSensitive:
-                    if hidden_col.spark_name.upper() == column_name.upper() and [
-                        q.upper() for q in hidden_col.qualifiers
-                    ] == [q.upper() for q in source_qualifiers]:
-                        return True
-                else:
-                    return True
-
-        return False
-

 class JoinColumnNameMap(ColumnNameMap):
     def __init__(
@@ -654,9 +570,6 @@ class JoinColumnNameMap(ColumnNameMap):
     ) -> None:
         self.left_column_mapping: ColumnNameMap = left_colmap
         self.right_column_mapping: ColumnNameMap = right_colmap
-        # Ensure attributes expected by base-class helpers exist to avoid AttributeError
-        # when generic code paths (e.g., hidden column checks) touch them.
-        self.hidden_columns: set[HiddenColumn] | None = None

     def get_snowpark_column_name_from_spark_column_name(
         self,
@@ -664,9 +577,6 @@
         *,
         allow_non_exists: bool = False,
         return_first: bool = False,
-        # JoinColumnNameMap will never be called with using columns, so these parameters are not used.
-        is_qualified: bool = False,
-        source_qualifiers: list[str] | None = None,
     ) -> str | None:
         snowpark_column_name_in_left = (
             self.left_column_mapping.get_snowpark_column_name_from_spark_column_name(

snowflake/snowpark_connect/config.py
@@ -8,7 +8,7 @@ import re
 import sys
 from collections import defaultdict
 from copy import copy, deepcopy
-from typing import Any
+from typing import Any, Dict

 import jpype
 import pyspark.sql.connect.proto.base_pb2 as proto_base
@@ -17,6 +17,7 @@ from tzlocal import get_localzone_name
 from snowflake import snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
+    unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
@@ -171,9 +172,6 @@ class GlobalConfig:
         "spark.app.name": lambda session, name: setattr(
             session, "query_tag", f"Spark-Connect-App-Name={name}"
         ),
-        "snowpark.connect.udf.packages": lambda session, packages: session.add_packages(
-            *packages.strip("[] ").split(",")
-        ),
         "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
             session, imports
         ),
@@ -260,6 +258,7 @@ SESSION_CONFIG_KEY_WHITELIST = {
     "spark.sql.execution.pythonUDTF.arrow.enabled",
     "spark.sql.tvf.allowMultipleTableArguments.enabled",
     "snowpark.connect.sql.passthrough",
+    "snowpark.connect.cte.optimization_enabled",
     "snowpark.connect.iceberg.external_volume",
     "snowpark.connect.sql.identifiers.auto-uppercase",
     "snowpark.connect.udtf.compatibility_mode",
@@ -284,6 +283,7 @@ class SessionConfig:
     default_session_config = {
         "snowpark.connect.sql.identifiers.auto-uppercase": "all_except_columns",
         "snowpark.connect.sql.passthrough": "false",
+        "snowpark.connect.cte.optimization_enabled": "true",
         "snowpark.connect.udtf.compatibility_mode": "false",
         "snowpark.connect.views.duplicate_column_names_handling_mode": "rename",
         "spark.sql.execution.pythonUDTF.arrow.enabled": "false",
@@ -293,6 +293,7 @@

     def __init__(self) -> None:
         self.config = deepcopy(self.default_session_config)
+        self.table_metadata: Dict[str, Dict[str, Any]] = {}

     def __getitem__(self, item: str) -> str:
         return self.get(item)
@@ -572,6 +573,12 @@ def set_snowflake_parameters(
                 snowpark_session.use_database(db)
             case (prev, curr) if prev != curr:
                 snowpark_session.use_schema(prev)
+        case "snowpark.connect.cte.optimization_enabled":
+            # Set CTE optimization on the snowpark session
+            cte_enabled = str_to_bool(value)
+            snowpark_session.cte_optimization_enabled = cte_enabled
+            logger.info(f"Updated snowpark session CTE optimization: {cte_enabled}")
+
         case _:
             pass

@@ -581,6 +588,16 @@ def get_boolean_session_config_param(name: str) -> bool:
     return str_to_bool(session_config[name])


+def get_string_session_config_param(name: str) -> str:
+    session_config = sessions_config[get_session_id()]
+    return str(session_config[name])
+
+
+def get_cte_optimization_enabled() -> bool:
+    """Get the CTE optimization configuration setting."""
+    return get_boolean_session_config_param("snowpark.connect.cte.optimization_enabled")
+
+
 def auto_uppercase_column_identifiers() -> bool:
     session_config = sessions_config[get_session_id()]
     return session_config[
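
The new "snowpark.connect.cte.optimization_enabled" key is whitelisted above, defaults to "true", is pushed onto the underlying Snowpark session by set_snowflake_parameters, and is read back through get_cte_optimization_enabled(). A minimal usage sketch; the Spark Connect session and connection string below are assumptions for illustration, not part of this diff:

    from pyspark.sql import SparkSession

    # Hypothetical connection; any Spark Connect client pointed at this server works.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Turning the option off is propagated server-side as:
    #   snowpark_session.cte_optimization_enabled = str_to_bool(value)
    spark.conf.set("snowpark.connect.cte.optimization_enabled", "false")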
@@ -616,3 +633,81 @@ def get_timestamp_type():
         # shouldn't happen since `spark.sql.timestampType` is always defined, and `spark.conf.unset` sets it to default (TIMESTAMP_LTZ)
         timestamp_type = TimestampType(TimestampTimeZone.LTZ)
     return timestamp_type
+
+
+def record_table_metadata(
+    table_identifier: str,
+    table_type: str,
+    data_source: str,
+    supports_column_rename: bool = True,
+) -> None:
+    """
+    Record metadata about a table for Spark compatibility checks.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+        table_type: "v1" or "v2"
+        data_source: Source format (parquet, csv, iceberg, etc.)
+        supports_column_rename: Whether the table supports RENAME COLUMN
+    """
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+
+    # Normalize table identifier for consistent lookup
+    # Use the full catalog.database.table identifier to avoid conflicts
+    normalized_identifier = table_identifier.upper().strip('"')
+
+    session_config.table_metadata[normalized_identifier] = {
+        "table_type": table_type,
+        "data_source": data_source,
+        "supports_column_rename": supports_column_rename,
+    }
+
+
+def get_table_metadata(table_identifier: str) -> Dict[str, Any] | None:
+    """
+    Get stored metadata for a table.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+
+    Returns:
+        Table metadata dict or None if not found
+    """
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+
+    normalized_identifier = unquote_if_quoted(table_identifier).upper()
+
+    return session_config.table_metadata.get(normalized_identifier)
+
+
+def check_table_supports_operation(table_identifier: str, operation: str) -> bool:
+    """
+    Check if a table supports a given operation based on metadata and config.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+        operation: Operation to check (e.g., "rename_column")
+
+    Returns:
+        True if operation is supported, False if should be blocked
+    """
+    table_metadata = get_table_metadata(table_identifier)
+
+    if not table_metadata:
+        return True
+
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+    enable_extensions = str_to_bool(
+        session_config.get("enable_snowflake_extension_behavior", "false")
+    )
+
+    if enable_extensions:
+        return True
+
+    if operation == "rename_column":
+        return table_metadata.get("supports_column_rename", True)
+
+    return True
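
A rough sketch of how the three new helpers compose; the identifier and data source are invented, and the calls assume they run inside a request context where get_session_id() resolves to the active session:

    # Record that a v2 Iceberg table was created and does not support RENAME COLUMN.
    record_table_metadata(
        table_identifier="my_catalog.my_db.events",  # made-up identifier
        table_type="v2",
        data_source="iceberg",
        supports_column_rename=False,
    )

    # Lookups are normalized (unquoted, upper-cased), so the entry is found again.
    meta = get_table_metadata("my_catalog.my_db.events")
    # meta == {"table_type": "v2", "data_source": "iceberg", "supports_column_rename": False}

    # Returns False here, unless enable_snowflake_extension_behavior is "true" for the session.
    check_table_supports_operation("my_catalog.my_db.events", "rename_column")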

snowflake/snowpark_connect/dataframe_container.py
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Callable

 from snowflake import snowpark
 from snowflake.snowpark.types import StructField, StructType
-from snowflake.snowpark_connect.hidden_column import HiddenColumn

 if TYPE_CHECKING:
     from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -61,7 +60,6 @@ class DataFrameContainer:
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
@@ -78,7 +76,6 @@
             column_metadata: Optional metadata dictionary
             column_qualifiers: Optional column qualifiers
             parent_column_name_map: Optional parent column name map
-            hidden_columns: Optional list of hidden column names
             table_name: Optional table name
             alias: Optional alias
             cached_schema_getter: Optional function to get cached schema
@@ -101,7 +98,6 @@
             column_metadata,
             column_qualifiers,
             parent_column_name_map,
-            hidden_columns,
         )

         # Determine the schema getter to use
@@ -226,7 +222,6 @@
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
     ) -> ColumnNameMap:
         """Create a ColumnNameMap with the provided configuration."""
         from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -237,7 +232,6 @@
             column_metadata=column_metadata,
             column_qualifiers=column_qualifiers,
             parent_column_name_map=parent_column_name_map,
-            hidden_columns=hidden_columns,
         )

     @staticmethod

snowflake/snowpark_connect/expression/map_expression.py
@@ -6,6 +6,7 @@ import datetime
 from collections import defaultdict

 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+from pyspark.errors.exceptions.connect import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
@@ -34,8 +35,10 @@ from snowflake.snowpark_connect.type_mapping import (
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     gen_sql_plan_id,
+    get_current_lambda_params,
     is_function_argument_being_resolved,
     is_lambda_being_resolved,
+    not_resolving_fun_args,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -134,7 +137,10 @@ def map_expression(
         case "expression_string":
             return map_sql_expr(exp, column_mapping, typer)
         case "extension":
-
+            # Extensions can be passed as function args, and we need to reset the context here.
+            # Matters only for resolving alias expressions in the extensions rel.
+            with not_resolving_fun_args():
+                return map_extension.map_extension(exp, column_mapping, typer)
         case "lambda_function":
             lambda_name, lambda_body = map_single_column_expression(
                 exp.lambda_function.function, column_mapping, typer
@@ -271,6 +277,30 @@
         case "unresolved_function":
             return map_func.map_unresolved_function(exp, column_mapping, typer)
         case "unresolved_named_lambda_variable":
+            # Validate that this lambda variable is in scope
+            var_name = exp.unresolved_named_lambda_variable.name_parts[0]
+            current_params = get_current_lambda_params()
+
+            if current_params and var_name not in current_params:
+                outer_col_name = (
+                    column_mapping.get_snowpark_column_name_from_spark_column_name(
+                        var_name, allow_non_exists=True
+                    )
+                )
+                if outer_col_name:
+                    col = snowpark_fn.col(outer_col_name)
+                    return ["namedlambdavariable()"], TypedColumn(
+                        col, lambda: typer.type(col)
+                    )
+                else:
+                    raise AnalysisException(
+                        f"Cannot resolve variable '{var_name}' within lambda function. "
+                        f"Lambda functions can access their own parameters and parent dataframe columns. "
+                        f"Current lambda parameters: {current_params}. "
+                        f"If '{var_name}' is an outer scope lambda variable from a nested lambda, "
+                        f"that is an unsupported feature in Snowflake SQL."
+                    )
+
             col = snowpark_fn.Column(
                 UnresolvedAttribute(exp.unresolved_named_lambda_variable.name_parts[0])
            )
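
The new unresolved_named_lambda_variable handling resolves a variable against the current lambda's parameters first, then against parent dataframe columns, and rejects captures of an outer lambda's variable. An illustrative client-side snippet; the dataframe and column names are assumptions:

    from pyspark.sql import functions as F

    # df is assumed to have an array column "xs" and an array-of-arrays column "matrix".
    df.select(F.transform("xs", lambda x: x + 1))  # fine: only the lambda's own parameter is used

    # Expected to raise AnalysisException under this change: the inner lambda refers to the
    # outer lambda's variable `row`, which is reported as unsupported in Snowflake SQL.
    df.select(F.transform("matrix", lambda row: F.transform(row, lambda v: v + F.size(row))))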

snowflake/snowpark_connect/expression/map_sql_expression.py
@@ -11,9 +11,10 @@ import pyspark.sql.connect.proto.expressions_pb2
 import pyspark.sql.connect.proto.types_pb2 as types_proto
 from google.protobuf.any_pb2 import Any
 from pyspark.errors.exceptions.base import AnalysisException
+from pyspark.sql.connect import functions as pyspark_functions

 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_proto
-from snowflake import
+from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.typed_column import TypedColumn
@@ -89,6 +90,11 @@ def as_scala_seq(input):
     )


+@cache
+def _scala_some():
+    return jpype.JClass("scala.Some")
+
+
 def map_sql_expr(
     exp: expressions_proto.Expression,
     column_mapping: ColumnNameMap,
@@ -223,9 +229,6 @@ def apply_filter_clause(


 def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Expression:
-    from snowflake.snowpark_connect.expression.map_expression import (
-        map_single_column_expression,
-    )
     from snowflake.snowpark_connect.relation.map_sql import map_logical_plan_relation

     class_name = str(exp.getClass().getSimpleName())
@@ -308,22 +311,23 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Express
             )
             proto = expressions_proto.Expression(extension=any_proto)
         case "ExpressionWithUnresolvedIdentifier":
-
-
-            session = snowpark.Session.get_active_session()
-            m = ColumnNameMap([], [], None)
-            expr = map_single_column_expression(
-                identifierExpr, m, ExpressionTyper.dummy_typer(session)
+            from snowflake.snowpark_connect.relation.map_sql import (
+                get_relation_identifier_name,
             )
-            value = session.range(1).select(expr[1].col).collect()[0][0]

-
-
-
-
-            )
-
-
+            value = unquote_if_quoted(get_relation_identifier_name(exp))
+            if getattr(pyspark_functions, value.lower(), None) is not None:
+                unresolved_function = exp.exprBuilder().apply(
+                    _scala_some()(value).toList()
+                )
+                proto = map_logical_plan_expression(unresolved_function)
+            else:
+                proto = expressions_proto.Expression(
+                    unresolved_attribute=expressions_proto.Expression.UnresolvedAttribute(
+                        unparsed_identifier=str(value),
+                        plan_id=None,
+                    ),
+                )
         case "InSubquery":
             rel_proto = map_logical_plan_relation(exp.query().plan())
             any_proto = Any()
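
The rewritten ExpressionWithUnresolvedIdentifier branch no longer evaluates the identifier through a throwaway dataframe; it takes the relation's identifier name and emits either an unresolved function (when the name matches a pyspark function) or an unresolved attribute. A sketch of SQL that could exercise it; table and column names are made up, and exact IDENTIFIER() support depends on the Spark parser:

    # Resolves to a plain column reference: "total" is not a pyspark function name.
    spark.sql("SELECT IDENTIFIER('total') FROM orders")

    # Takes the function path, since getattr(pyspark_functions, "sum") exists.
    spark.sql("SELECT IDENTIFIER('sum')(total) FROM orders")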

snowflake/snowpark_connect/expression/map_unresolved_attribute.py
@@ -22,6 +22,7 @@ from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     get_outer_dataframes,
     get_plan_id_map,
+    is_lambda_being_resolved,
     resolve_lca_alias,
 )
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -162,7 +163,6 @@ def map_unresolved_attribute(
     attr_name = ".".join(name_parts)

     has_plan_id = exp.unresolved_attribute.HasField("plan_id")
-    source_qualifiers = None

     if has_plan_id:
         plan_id = exp.unresolved_attribute.plan_id
@@ -171,27 +171,13 @@
         assert (
             target_df is not None
         ), f"resolving an attribute of a unresolved dataframe {plan_id}"
-
-        # Get the qualifiers for this column from the target DataFrame
-        source_qualifiers = (
-            target_df_container.column_map.get_qualifier_for_spark_column(
-                name_parts[-1]
-            )
-        )
-
-        if hasattr(column_mapping, "hidden_columns"):
-            hidden = column_mapping.hidden_columns
-        else:
-            hidden = None
-
         column_mapping = target_df_container.column_map
-        column_mapping.hidden_columns = hidden
         typer = ExpressionTyper(target_df)

-    def get_col(snowpark_name
+    def get_col(snowpark_name):
         return (
             snowpark_fn.col(snowpark_name)
-            if not has_plan_id
+            if not has_plan_id
             else target_df.col(snowpark_name)
         )

@@ -276,17 +262,10 @@
         quoted_attr_name = name_parts[0]

         snowpark_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
-            quoted_attr_name,
-            allow_non_exists=True,
-            is_qualified=has_plan_id,
-            source_qualifiers=source_qualifiers if has_plan_id else None,
+            quoted_attr_name, allow_non_exists=True
         )
-
         if snowpark_name is not None:
-
-                quoted_attr_name, source_qualifiers
-            )
-            col = get_col(snowpark_name, is_hidden)
+            col = get_col(snowpark_name)
             qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_attr_name)
         else:
             # this means it has to be a struct column with a field name
@@ -356,6 +335,23 @@
         return (unqualified_name, typed_col)

     if snowpark_name is None:
+        # Check if we're inside a lambda and trying to reference an outer column
+        # This catches direct column references (not lambda variables)
+        if is_lambda_being_resolved() and column_mapping:
+            # Check if this column exists in the outer scope (not lambda params)
+            outer_col_name = (
+                column_mapping.get_snowpark_column_name_from_spark_column_name(
+                    attr_name, allow_non_exists=True
+                )
+            )
+            if outer_col_name:
+                # This is an outer scope column being referenced inside a lambda
+                raise AnalysisException(
+                    f"Reference to non-lambda variable '{attr_name}' within lambda function. "
+                    f"Lambda functions can only access their own parameters. "
+                    f"If '{attr_name}' is a table column, it must be passed as an explicit parameter to the enclosing function."
+                )
+
         if has_plan_id:
             raise AnalysisException(
                 f'[RESOLVED_REFERENCE_COLUMN_NOT_FOUND] The column "{attr_name}" does not exist in the target dataframe.'
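
This check covers explicit column references (F.col or bare names) inside a lambda body, as opposed to the lambda-variable nodes handled in map_expression.py above; such references now fail with a clear message instead of an unresolvable name downstream. Illustrative snippet; the dataframe and data are assumptions:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(2, [1, 2, 3])], ["scale", "xs"])  # assumed example data

    # Expected to raise AnalysisException: "Reference to non-lambda variable 'scale' within
    # lambda function ... it must be passed as an explicit parameter to the enclosing function."
    df.select(F.transform("xs", lambda x: x * F.col("scale")))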

snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -711,6 +711,9 @@ def map_unresolved_function(
                 "-",
             )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), NullType()) | (NullType(), DateType()):
+                    result_type = LongType()
+                    result_exp = snowpark_fn.lit(None).cast(result_type)
                 case (NullType(), _) | (_, NullType()):
                     result_type = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
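
The added arm gives date arithmetic with a NULL operand a definite result instead of falling through to the generic branch: the expression becomes lit(None) cast to the same LongType the DateType/DateType arm uses. A sketch of a query that hits it; the literal is made up:

    # Either operand order is covered: (DateType, NullType) or (NullType, DateType).
    spark.sql("SELECT DATE'2024-01-01' - NULL AS day_diff")  # NULL result typed as LongType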
@@ -724,7 +727,10 @@
                     result_type = LongType()
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
-                    if
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
                         result_type = TimestampType()
                         result_exp = snowpark_args[0] - snowpark_args[1]
                     else:
@@ -2421,7 +2427,7 @@
                 "try_to_date",
                 snowpark_fn.cast(
                     truncated_date,
-                    TimestampType(
+                    TimestampType(),
                 ),
                 snowpark_args[1],
             )
@@ -2613,9 +2619,18 @@
                     result_type = input_type.element_type
                     result_exp = fn(snowpark_args[0])
                 case _:
-
-
-
+                    # Check if the type has map-like attributes before accessing them
+                    if hasattr(input_type, "key_type") and hasattr(
+                        input_type, "value_type"
+                    ):
+                        spark_col_names = ["key", "value"]
+                        result_exp = fn(snowpark_args[0])
+                        result_type = [input_type.key_type, input_type.value_type]
+                    else:
+                        # Throw proper error for types without key_type/value_type attributes
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the ("ARRAY" or "MAP") type, however "{snowpark_arg_names[0]}" has the type "{str(input_type)}".'
+                        )
         case "expm1":
             spark_function_name = f"EXPM1({snowpark_arg_names[0]})"
             result_exp = snowpark_fn.exp(*snowpark_args) - 1
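
The enclosing function name is not visible in this hunk; the key/value output columns suggest an explode-style function that accepts ARRAY or MAP input, so the snippet below uses explode as an assumption. The point of the change is the error path: a non-collection argument now raises a Spark-style data-type-mismatch error instead of an AttributeError on the missing key_type/value_type attributes:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(1,)], ["n"])  # "n" is a plain integer column (assumed)

    # Expected to raise AnalysisException:
    #   [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] ... requires the ("ARRAY" or "MAP") type ...
    df.select(F.explode("n"))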
@@ -8725,7 +8740,7 @@ def _resolve_function_with_lambda(
     artificial_df = Session.get_active_session().create_dataframe([], schema)
     set_schema_getter(artificial_df, lambda: schema)

-    with resolving_lambda_function():
+    with resolving_lambda_function(names):
         return map_expression(
             (
                 lambda_exp.lambda_function.function
@@ -9911,7 +9926,10 @@ def _get_spark_function_name(
             return f"({date_param_name1} {operation_op} {date_param_name2})"
         case (StringType(), DateType()):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
-            if
+            if (
+                hasattr(col1.col._expr1, "pretty_name")
+                and "INTERVAL" == col1.col._expr1.pretty_name
+            ):
                 return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
             elif global_config.spark_sql_ansi_enabled and function_name == "+":
                 return f"{operation_func}(cast({date_param_name2} as date), cast({snowpark_arg_names[0]} as double))"
@@ -9919,9 +9937,9 @@
                 return f"({snowpark_arg_names[0]} {operation_op} {date_param_name2})"
         case (DateType(), StringType()):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
-            if (
-
-
+            if global_config.spark_sql_ansi_enabled or (
+                hasattr(col2.col._expr1, "pretty_name")
+                and "INTERVAL" == col2.col._expr1.pretty_name
             ):
                 return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
             else:
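
Both rewritten branches now guard the INTERVAL check with hasattr before reading pretty_name; they only affect how the Spark-style result column name is composed for date plus/minus string or interval operands (the earlier -724 hunk handles the corresponding result type). A small sketch of the kind of expression involved; the literals are made up:

    # Date minus an interval literal takes the interval-style name and, per the earlier
    # DateType/StringType arm, a TimestampType result; a plain string falls into the else branch.
    spark.sql("SELECT DATE'2024-03-01' - INTERVAL '1' DAY AS d")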