snowpark-connect 0.32.0__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +92 -27
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/map_sql_expression.py +12 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +58 -21
- snowflake/snowpark_connect/expression/map_unresolved_function.py +62 -27
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +2 -4
- snowflake/snowpark_connect/relation/map_column_ops.py +5 -0
- snowflake/snowpark_connect/relation/map_join.py +218 -146
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +102 -16
- snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
- snowflake/snowpark_connect/relation/utils.py +46 -0
- snowflake/snowpark_connect/relation/write/map_write.py +186 -275
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +8 -1
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +3 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +35 -93
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/column_name_handler.py
@@ -27,6 +27,7 @@ from snowflake.snowpark_connect.utils.context import (
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
+from snowflake.snowpark_connect.utils.sequence import next_unique_num
 
 ALREADY_QUOTED = re.compile('^(".+")$', re.DOTALL)
 
@@ -46,6 +47,7 @@ def set_schema_getter(df: DataFrame, get_schema: Callable[[], StructType]) -> No
     df.__class__ = PatchedDataFrame
 
 
+# TODO replace plan_id-offset with single unique value
 def make_column_names_snowpark_compatible(
     names: list[str], plan_id: int, offset: int = 0
 ) -> list[str]:
@@ -91,6 +93,14 @@ def make_column_names_snowpark_compatible(
     ]
 
 
+def make_unique_snowpark_name(spark_name: str) -> str:
+    """
+    Returns a snowpark column name that's guaranteed to be unique in this session,
+    by appending "#<unique number>" to the given spark name.
+    """
+    return quote_name(f"{spark_name}#{next_unique_num()}")
+
+
 @dataclass(frozen=True)
 class ColumnNames:
     spark_name: str
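
Note: the "#<unique number>" suffix comes from the new utils/sequence.py module (+21 lines in this release). A minimal sketch of the idea only, with a plain itertools counter standing in for next_unique_num and simple double-quoting standing in for Snowpark's quote_name:

    from itertools import count

    _session_counter = count(1)  # stand-in for snowflake.snowpark_connect.utils.sequence

    def next_unique_num() -> int:
        return next(_session_counter)

    def make_unique_snowpark_name(spark_name: str) -> str:
        # quoting keeps the '#<n>' suffix inside a single Snowflake identifier
        return f'"{spark_name}#{next_unique_num()}"'

    print(make_unique_snowpark_name("id"))  # "id#1"
    print(make_unique_snowpark_name("id"))  # "id#2"
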
@@ -137,7 +147,7 @@ class ColumnNameMap:
 
         # Rename chain dictionary to track column renaming history
         self.rename_chains: dict[str, str] = {}  # old_name -> new_name mapping
-        self.current_columns: set[str] = set()  #
+        self.current_columns: set[str] = set()  # current column names
 
         # Parent ColumnNameMap classes
         self._parent_column_name_map = parent_column_name_map
@@ -170,7 +180,7 @@ class ColumnNameMap:
                 snowpark_name=snowpark_column_names[i],
                 qualifiers=column_qualifiers[i]
                 if column_qualifiers and column_qualifiers[i]
-                else
+                else set(),
                 catalog_info=catalog_info,
                 database_info=database_info,
             )
@@ -511,21 +521,6 @@ class ColumnNameMap:
             if self._quote_if_unquoted(c.snowpark_name) not in cols_to_drop
         ]
 
-    def get_qualifier_for_spark_column(
-        self,
-        spark_column_name: str,
-    ) -> ColumnQualifier:
-        """
-        Backward compatibility: returns the first qualifier for the given Spark column name.
-        Throws if more than one qualifier exists.
-        """
-        qualifiers = self.get_qualifiers_for_spark_column(spark_column_name)
-        if len(qualifiers) > 1:
-            raise ValueError(
-                "Shouldn't happen. Multiple qualifiers found; expected only one."
-            )
-        return next(iter(qualifiers))
-
     def get_qualifiers_for_spark_column(
         self,
         spark_column_name: str,
@@ -544,7 +539,7 @@ class ColumnNameMap:
         col = mapping.get(name)
 
         if col is None or len(col) == 0:
-            return
+            return set()
 
         return col[0].qualifiers
 
@@ -605,7 +600,7 @@ class ColumnNameMap:
                 removed_index.add(index)
                 spark_columns.append(new_spark_columns[index])
                 snowpark_columns.append(new_snowpark_columns[index])
-                qualifiers.append(
+                qualifiers.append(set())
             else:
                 spark_columns.append(c.spark_name)
                 snowpark_columns.append(c.snowpark_name)
@@ -615,7 +610,7 @@ class ColumnNameMap:
             if i not in removed_index:
                 spark_columns.append(new_spark_columns[i])
                 snowpark_columns.append(new_snowpark_columns[i])
-                qualifiers.append(
+                qualifiers.append(set())
 
         return spark_columns, snowpark_columns, qualifiers
 
@@ -625,6 +620,67 @@ class ColumnNameMap:
         else:
             return spark_name.upper()
 
+    def get_columns_after_join(
+        self, other: ColumnNameMap, join_columns: list[str]
+    ) -> list[ColumnNames]:
+        """
+        Returns a list of columns (names and qualifiers) after a using_columns join with the given column map
+        """
+
+        join_column_names = {self._normalized_spark_name(c) for c in join_columns}
+        other_join_columns: dict[str, ColumnNames] = {}
+        other_remaining_columns: list[ColumnNames] = []
+        for oc in other.columns:
+            col_name = self._normalized_spark_name(oc.spark_name)
+            # only take the first matching column
+            if col_name in join_column_names and col_name not in other_join_columns:
+                other_join_columns[col_name] = oc
+            else:
+                other_remaining_columns.append(oc)
+
+        joined_columns: list[ColumnNames] = []
+        visited: set[str] = set()
+        # add local columns first, we're in the left side of the join
+        for c in self.columns:
+            col_name = self._normalized_spark_name(c.spark_name)
+            if col_name in join_column_names and col_name not in visited:
+                visited.add(col_name)
+                qualifiers = c.qualifiers | other_join_columns[col_name].qualifiers
+                joined_columns.append(
+                    ColumnNames(c.spark_name, c.snowpark_name, qualifiers)
+                )
+            else:
+                joined_columns.append(c)
+
+        # add other columns, excluding join columns
+        return joined_columns + other_remaining_columns
+
+    def get_column_indexes(self, spark_names: list[str]) -> list[int]:
+        """
+        Returns the first positions of the given spark_names in this column mapping.
+        Used to reorder columns after a using_columns join.
+        """
+        # mapping from normalized spark name ot the first index of the column in the mapping
+        column_indexes = {}
+
+        for i, c in enumerate(self.columns):
+            col_name = self._normalized_spark_name(c.spark_name)
+            if col_name not in column_indexes:
+                column_indexes[col_name] = i
+
+        # return indexes for given columns
+        return [column_indexes[self._normalized_spark_name(c)] for c in spark_names]
+
+    def get_conflicting_snowpark_columns(self, other: ColumnNameMap) -> set[str]:
+        conflicting_columns = set()
+        snowpark_names = {c.snowpark_name for c in self.columns}
+
+        for c in other.columns:
+            if c.snowpark_name in snowpark_names:
+                conflicting_columns.add(c.snowpark_name)
+
+        return conflicting_columns
+
 
 class JoinColumnNameMap(ColumnNameMap):
     def __init__(
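
Note: to make the new using-columns logic concrete, here is a hedged, self-contained sketch of the merge rule get_columns_after_join encodes (left columns first, each join column emitted once with qualifiers from both sides, then the remaining right-side columns). The ColumnNames stand-in and the lower-casing below are illustrative, not the module's actual normalization:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ColumnNames:
        spark_name: str
        snowpark_name: str
        qualifiers: frozenset

    def columns_after_join(left, right, join_columns):
        join_names = {c.lower() for c in join_columns}
        right_join, right_rest = {}, []
        for rc in right:
            name = rc.spark_name.lower()
            if name in join_names and name not in right_join:
                right_join[name] = rc  # only the first right-side match is merged
            else:
                right_rest.append(rc)

        merged, visited = [], set()
        for lc in left:
            name = lc.spark_name.lower()
            if name in join_names and name in right_join and name not in visited:
                visited.add(name)
                merged.append(ColumnNames(
                    lc.spark_name,
                    lc.snowpark_name,
                    lc.qualifiers | right_join[name].qualifiers,
                ))
            else:
                merged.append(lc)
        # join columns are not repeated from the right side
        return merged + right_rest
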
@@ -784,19 +840,28 @@ class JoinColumnNameMap(ColumnNameMap):
     def get_qualifiers_for_spark_column(
         self, spark_column_name: str
     ) -> set[ColumnQualifier]:
-
-
-    def get_qualifier_for_spark_column(self, spark_column_name: str) -> ColumnQualifier:
-        qualifier_left = self.left_column_mapping.get_qualifier_for_spark_column(
+        qualifiers_left = self.left_column_mapping.get_qualifiers_for_spark_column(
             spark_column_name
         )
-
+        qualifiers_right = self.right_column_mapping.get_qualifiers_for_spark_column(
             spark_column_name
         )
 
-        if (
+        if (len(qualifiers_left) > 0) and (len(qualifiers_right) > 0):
             exception = AnalysisException(f"Ambiguous column name {spark_column_name}")
             attach_custom_error_code(exception, ErrorCodes.AMBIGUOUS_COLUMN_NAME)
             raise exception
 
-        return
+        return qualifiers_right if len(qualifiers_left) == 0 else qualifiers_left
+
+    def get_columns_after_join(
+        self, other: ColumnNameMap, join_columns: list[str]
+    ) -> list[ColumnNames]:
+        exception = NotImplementedError("Method not implemented!")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception
+
+    def get_column_indexes(self, spark_names: list[str]) -> list[int]:
+        exception = NotImplementedError("Method not implemented!")
+        attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+        raise exception

snowflake/snowpark_connect/column_qualifier.py
@@ -23,10 +23,6 @@ class ColumnQualifier:
     def is_empty(self) -> bool:
         return len(self.parts) == 0
 
-    @classmethod
-    def no_qualifier(cls) -> ColumnQualifier:
-        return cls(())
-
     def all_qualified_names(self, name: str) -> list[str]:
         qualifier_parts = self.parts
         qualifier_prefixes = [

snowflake/snowpark_connect/expression/hybrid_column_map.py
@@ -148,14 +148,15 @@ class HybridColumnMap:
                 exp, self.aggregated_column_map, self.aggregated_typer
             )
 
-        # For other expression types, try aggregated context first (likely references to computed values)
         try:
+            # 1. Evaluate the expression using the input grouping columns. i.e input_df.
+            # If not found, use the aggregate alias.
+            return map_expression(exp, self.input_column_map, self.input_typer)
+        except Exception:
+            # Fall back to input context
             return map_expression(
                 exp, self.aggregated_column_map, self.aggregated_typer
             )
-        except Exception:
-            # Fall back to input context
-            return map_expression(exp, self.input_column_map, self.input_typer)
 
 
 def create_hybrid_column_map_for_having(
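
Note: the visible effect of swapping the try/except bodies is a change in resolution order; a minimal illustration (the resolver names are placeholders, not the module's API):

    def resolve(exp, resolve_against_input, resolve_against_aggregates):
        # new order: input/grouping columns win, aggregate aliases are the fallback
        try:
            return resolve_against_input(exp)
        except Exception:
            return resolve_against_aggregates(exp)
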

snowflake/snowpark_connect/expression/map_sql_expression.py
@@ -418,13 +418,21 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Express
             )
         )
         case "Like" | "ILike" | "RLike":
+            arguments = [
+                map_logical_plan_expression(e)
+                for e in list(as_java_list(exp.children()))
+            ]
+            # exp.escapeChar() returns a JPype JChar - convert to string and create a literal
+            if getattr(exp, "escapeChar", None) is not None:
+                escape_char_str = str(exp.escapeChar())
+                escape_literal = expressions_proto.Expression(
+                    literal=expressions_proto.Expression.Literal(string=escape_char_str)
+                )
+                arguments.append(escape_literal)
             proto = expressions_proto.Expression(
                 unresolved_function=expressions_proto.Expression.UnresolvedFunction(
                     function_name=class_name.lower(),
-                    arguments=
-                        map_logical_plan_expression(e)
-                        for e in list(as_java_list(exp.children()))
-                    ],
+                    arguments=arguments,
                 )
             )
         case "LikeAny" | "NotLikeAny" | "LikeAll" | "NotLikeAll":
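
Note: this affects Spark SQL patterns with an explicit ESCAPE clause; the escape character is now forwarded as a trailing string literal instead of being dropped. Illustrative shape only:

    # the '!' from "name LIKE '%100!%%' ESCAPE '!'" becomes one extra literal argument
    like_arguments = ["<column expression>", "<pattern expression>"]
    escape_char = "!"  # from exp.escapeChar(), a JPype JChar converted to str
    like_arguments.append(escape_char)
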

snowflake/snowpark_connect/expression/map_unresolved_attribute.py
@@ -3,6 +3,7 @@
 #
 
 import re
+from typing import Any
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 from pyspark.errors.exceptions.connect import AnalysisException
@@ -275,12 +276,33 @@ def map_unresolved_attribute(
     else:
         quoted_attr_name = name_parts[0]
 
-
-
-
+    # Helper function to try finding a column in current and outer scopes
+    def try_resolve_column(column_name: str) -> tuple[str | None, Any]:
+        # Try current scope
+        snowpark_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
+            column_name, allow_non_exists=True
+        )
+        if snowpark_name is not None:
+            return snowpark_name, column_mapping
+
+        # Try outer scopes
+        for outer_df in get_outer_dataframes():
+            snowpark_name = (
+                outer_df.column_map.get_snowpark_column_name_from_spark_column_name(
+                    column_name, allow_non_exists=True
+                )
+            )
+            if snowpark_name is not None:
+                return snowpark_name, outer_df.column_map
+
+        return None, None
+
+    # Try to resolve the full qualified name first
+    snowpark_name, found_column_map = try_resolve_column(quoted_attr_name)
+
     if snowpark_name is not None:
         col = get_col(snowpark_name)
-        qualifiers =
+        qualifiers = found_column_map.get_qualifiers_for_spark_column(quoted_attr_name)
     else:
         # this means it has to be a struct column with a field name
         snowpark_name: str | None = None
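
Note: a hedged sketch of the lookup order try_resolve_column establishes, with plain dictionaries standing in for the column maps (get_outer_dataframes and the real map API are not reproduced here):

    def try_resolve(name, current_scope, outer_scopes):
        # current scope first, then enclosing dataframes, innermost outward
        for scope in [current_scope, *outer_scopes]:
            resolved = scope.get(name)  # None when the column is absent
            if resolved is not None:
                return resolved, scope  # also report which map resolved the name
        return None, None

    print(try_resolve("k", {"a": '"A#1"'}, [{"k": '"K#7"'}]))  # ('"K#7"', {'k': '"K#7"'})
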
@@ -295,28 +317,43 @@ def map_unresolved_attribute(
         # For qualified names like "table.column.field", we need to find the column part
         for i in range(len(name_parts)):
             candidate_column = name_parts[i]
-            snowpark_name = (
-
-                    candidate_column, allow_non_exists=True
-                )
-            )
+            snowpark_name, found_column_map = try_resolve_column(candidate_column)
+
             if snowpark_name is not None:
                 column_part_index = i
                 break
 
-
-
-
-
-
-
-
-
-
-
+        # Validate qualifier scope: if we found a column but skipped prefix parts,
+        # those prefix parts could be valid qualifiers for the column
+        # We have prefix parts like 'nt1' in 'nt1.k' that were skipped
+        maybe_qualified = column_part_index > 0
+        if (
+            snowpark_name is not None
+            and maybe_qualified
+            and found_column_map is not None
+        ):
+            prefix_parts = name_parts[:column_part_index]
+            found_col_qualifiers = found_column_map.get_qualifiers_for_spark_column(
+                candidate_column
+            )
 
-            if
-
+            # Check if any qualifier matches the prefix
+            has_matching_qualifier = False
+            for qual in found_col_qualifiers:
+                if not qual.is_empty and len(qual.parts) >= len(prefix_parts):
+                    if qual.parts[-len(prefix_parts) :] == tuple(prefix_parts):
+                        has_matching_qualifier = True
+                        break
+
+            # If no matching qualifier, it's a scope violation
+            if not has_matching_qualifier:
+                # The prefix is not a valid qualifier for this column - scope violation!
+                exception = AnalysisException(
+                    f'[UNRESOLVED_COLUMN] Column "{attr_name}" cannot be resolved. '
+                    f'The table or alias "{".".join(prefix_parts)}" is not in scope or does not exist.'
+                )
+                attach_custom_error_code(exception, ErrorCodes.COLUMN_NOT_FOUND)
+                raise exception
 
         if snowpark_name is None:
             # Attempt LCA fallback.

snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -24,7 +24,6 @@ from typing import List, Optional
 from urllib.parse import quote, unquote
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
-import pyspark.sql.functions as pyspark_functions
 from google.protobuf.message import Message
 from pyspark.errors.exceptions.base import (
     AnalysisException,
@@ -101,6 +100,7 @@ from snowflake.snowpark_connect.expression.map_unresolved_star import (
 )
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.catalogs.utils import CURRENT_CATALOG_NAME
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import (
     map_json_schema_to_snowpark,
     map_pyspark_types_to_snowpark_types,
@@ -400,9 +400,8 @@ def map_unresolved_function(
     result_type: Optional[DataType | List[DateType]] = None
     qualifier_parts: List[str] = []
 
-
-    if
-    # Used by the GROUP BY ALL implementation. Far from ideal, but it seems to work...
+    # Check if this is an aggregate function (used by GROUP BY ALL implementation)
+    if is_aggregate_function(function_name):
         add_sql_aggregate_function()
 
     def _type_with_typer(col: Column) -> TypedColumn:
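
Note: is_aggregate_function now lives in relation/utils.py (+46 lines this release); its exact contents are not shown in this diff. A hedged sketch of the check it replaces, with a purely illustrative function list:

    _AGGREGATES = {"count", "sum", "avg", "min", "max"}  # illustrative, not the real list

    def is_aggregate_function(function_name: str) -> bool:
        return function_name.lower() in _AGGREGATES

    # e.g. "SELECT dept, COUNT(*) FROM employees GROUP BY ALL":
    # COUNT is flagged as an aggregate, so dept becomes the implicit grouping column.
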
@@ -912,15 +911,28 @@ def map_unresolved_function(
         ):
             # String + YearMonthInterval: Spark tries to cast string to double first, throws error if it fails
             result_type = StringType()
+            raise_error = _raise_error_helper(StringType(), AnalysisException)
             if isinstance(snowpark_typed_args[0].typ, StringType):
-
-
-
-
+                # Try to cast string to double, if it fails (returns null), raise exception
+                cast_result = snowpark_fn.try_cast(snowpark_args[0], "double")
+                result_exp = snowpark_fn.when(
+                    cast_result.is_null(),
+                    raise_error(
+                        snowpark_fn.lit(
+                            f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        )
+                    ),
+                ).otherwise(cast_result + snowpark_args[1])
             else:
-
-
-
+                cast_result = snowpark_fn.try_cast(snowpark_args[1], "double")
+                result_exp = snowpark_fn.when(
+                    cast_result.is_null(),
+                    raise_error(
+                        snowpark_fn.lit(
+                            f'The value \'{snowpark_args[0]}\' of the type {snowpark_typed_args[0].typ} cannot be cast to "DOUBLE" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        )
+                    ),
+                ).otherwise(snowpark_args[0] + cast_result)
         case (StringType(), t) | (t, StringType()) if isinstance(
             t, DayTimeIntervalType
         ):
@@ -6184,15 +6196,19 @@ def map_unresolved_function(
                 or isinstance(arg_type, DayTimeIntervalType)
                 else DoubleType()
             )
-
-
-
-
-
-
-
-
-
+        case "pow" | "power":
+            spark_function_name = f"{function_name if function_name == 'pow' else function_name.upper()}({snowpark_arg_names[0]}, {snowpark_arg_names[1]})"
+            if not spark_sql_ansi_enabled:
+                snowpark_args = _validate_numeric_args(
+                    function_name, snowpark_typed_args, snowpark_args
+                )
+            result_exp = snowpark_fn.when(
+                snowpark_fn.equal_nan(snowpark_fn.cast(snowpark_args[0], FloatType()))
+                | snowpark_fn.equal_nan(
+                    snowpark_fn.cast(snowpark_args[1], FloatType())
+                ),
+                NAN,
+            ).otherwise(snowpark_fn.pow(snowpark_args[0], snowpark_args[1]))
             result_type = DoubleType()
         case "product":
             col = snowpark_args[0]
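
Note: the new when/otherwise branch makes pow/power follow Spark's NaN propagation. Restated with plain Python floats (this mirrors the mapped expression, not Snowflake's native POWER behaviour):

    import math

    def spark_connect_pow(x: float, y: float) -> float:
        # either operand being NaN forces a NaN result; otherwise delegate to pow
        if math.isnan(x) or math.isnan(y):
            return math.nan
        return math.pow(x, y)

    print(spark_connect_pow(2.0, 10.0))          # 1024.0
    print(spark_connect_pow(float("nan"), 0.0))  # nan
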
@@ -9458,15 +9474,21 @@ def map_unresolved_function(
             result_exp = snowpark_fn.year(snowpark_fn.to_date(snowpark_args[0]))
             result_type = LongType()
         case binary_method if binary_method in ("to_binary", "try_to_binary"):
-            binary_format = "hex"
+            binary_format = snowpark_fn.lit("hex")
+            arg_str = snowpark_fn.cast(snowpark_args[0], StringType())
             if len(snowpark_args) > 1:
                 binary_format = snowpark_args[1]
             result_exp = snowpark_fn.when(
                 snowpark_args[0].isNull(), snowpark_fn.lit(None)
             ).otherwise(
                 snowpark_fn.function(binary_method)(
-                    snowpark_fn.
-
+                    snowpark_fn.when(
+                        (snowpark_fn.length(arg_str) % 2 == 1)
+                        & (snowpark_fn.lower(binary_format) == snowpark_fn.lit("hex")),
+                        snowpark_fn.concat(snowpark_fn.lit("0"), arg_str),
+                    ).otherwise(arg_str),
+                    binary_format,
+                )
             )
             result_type = BinaryType()
         case udtf_name if udtf_name.lower() in session._udtfs:
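
Note: the inner when() left-pads odd-length input with "0" before it reaches TO_BINARY / TRY_TO_BINARY, and only when the requested format is hex. A plain-Python illustration of the padding rule:

    def normalize_hex(value: str) -> str:
        # odd-length hex is left-padded, e.g. '7' -> '07', 'abc' -> '0abc'
        return "0" + value if len(value) % 2 == 1 else value

    print(bytes.fromhex(normalize_hex("7")))    # b'\x07'
    print(bytes.fromhex(normalize_hex("abc")))  # b'\n\xbc'
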
@@ -10705,12 +10727,18 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), new_type
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-
+                    # Use _divnull to handle case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), new_type
             else:
                 new_type = DecimalType(
                     precision=min(38, arg_type.precision + 10), scale=arg_type.scale
                 )
-
+                # Return NULL when there are no non-null values (i.e., all values are NULL); this is handled using case/when to check for non-null values for both SUM and the sum component of AVG calculations.
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, new_type
 
         case _:
             # If the input column is floating point (double and float are synonymous in Snowflake per
@@ -10728,9 +10756,16 @@ def _try_sum_helper(
                     return snowpark_fn.lit(None), DoubleType()
                 else:
                     non_null_rows = snowpark_fn.count(col_name)
-
+                    # Use _divnull to handle case when non_null_rows is 0
+                    return _divnull(aggregate_sum, non_null_rows), DoubleType()
             else:
-
+                # When all values are NULL, SUM should return NULL (not 0)
+                # Use case/when to return NULL when there are no non-null values (i.e., all values are NULL)
+                non_null_rows = snowpark_fn.count(col_name)
+                result = snowpark_fn.when(
+                    non_null_rows == 0, snowpark_fn.lit(None)
+                ).otherwise(aggregate_sum)
+                return result, DoubleType()
 
 
 def _get_type_precision(typ: DataType) -> tuple[int, int]:
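
Note: the added case/when enforces standard SQL SUM semantics: an input with no non-null values yields NULL rather than 0. Restated in plain Python (an illustration of the semantics, not the Snowpark expression itself):

    def sum_with_null_semantics(values):
        non_null = [v for v in values if v is not None]
        return sum(non_null) if non_null else None

    print(sum_with_null_semantics([None, None]))  # None
    print(sum_with_null_semantics([1, None, 2]))  # 3
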

snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py (new file)
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#