snowpark-connect 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/column_name_handler.py +73 -100
- snowflake/snowpark_connect/column_qualifier.py +47 -0
- snowflake/snowpark_connect/dataframe_container.py +3 -2
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
- snowflake/snowpark_connect/expression/map_expression.py +5 -4
- snowflake/snowpark_connect/expression/map_extension.py +12 -6
- snowflake/snowpark_connect/expression/map_sql_expression.py +38 -3
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +5 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +869 -107
- snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
- snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +4 -3
- snowflake/snowpark_connect/relation/map_extension.py +10 -9
- snowflake/snowpark_connect/relation/map_join.py +5 -2
- snowflake/snowpark_connect/relation/map_sql.py +33 -1
- snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
- snowflake/snowpark_connect/relation/write/map_write.py +29 -14
- snowflake/snowpark_connect/server.py +1 -2
- snowflake/snowpark_connect/type_mapping.py +36 -3
- snowflake/snowpark_connect/typed_column.py +8 -6
- snowflake/snowpark_connect/utils/session.py +19 -3
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +1 -1
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +5 -2
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +36 -37
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ sys.path.append(str(pathlib.Path(__file__).parent / "includes/python"))
 
 from .server import get_session  # noqa: E402, F401
 from .server import start_session  # noqa: E402, F401
+from .utils.session import skip_session_configuration  # noqa: E402, F401
 
 # Turn off catalog warning for Snowpark
 sp_logger = logging.getLogger("snowflake.snowpark")
@@ -13,12 +13,10 @@ from functools import cached_property
 from pyspark.errors.exceptions.base import AnalysisException
 
 from snowflake.snowpark import DataFrame
-from snowflake.snowpark._internal.analyzer.analyzer_utils import (
-    quote_name_without_upper_casing,
-    unquote_if_quoted,
-)
+from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark._internal.utils import quote_name
 from snowflake.snowpark.types import StructType
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
@@ -97,31 +95,15 @@ def make_column_names_snowpark_compatible(
 class ColumnNames:
     spark_name: str
     snowpark_name: str
-    qualifiers:
+    qualifiers: set[ColumnQualifier]
     catalog_info: str | None = None  # Catalog from fully qualified name
     database_info: str | None = None  # Database from fully qualified name
 
-
-
-
-
-
-
-    For example, if the column name is 'id' and the qualifiers are ['db', 'table'],
-    then the possible Spark names are:
-    ['id', 'db.table.id', 'table.id']
-    """
-    spark_name = column_names.spark_name
-    qualifiers = column_names.qualifiers
-
-    qualifier_suffixes_list = [
-        ".".join(quote_name_without_upper_casing(x) for x in qualifiers[i:])
-        for i in range(len(qualifiers))
-    ]
-    return [spark_name] + [
-        f"{qualifier_suffix}.{spark_name}"
-        for qualifier_suffix in qualifier_suffixes_list
-    ]
+    def all_spark_names_including_qualified_names(self):
+        all_names = [self.spark_name]
+        for qualifier in self.qualifiers:
+            all_names.extend(qualifier.all_qualified_names(self.spark_name))
+        return all_names
 
 
 class ColumnNameMap:
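The removed helper and the new ColumnNames.all_spark_names_including_qualified_names produce the expansion the old docstring described: a column 'id' qualified by ['db', 'table'] is addressable as 'id', 'db.table.id', and 'table.id'. A minimal standalone sketch of that suffix expansion, using a plain double-quoting stand-in for Snowpark's quote_name_without_upper_casing (an assumption for illustration only):

def quote(part: str) -> str:
    # Stand-in for quote_name_without_upper_casing: quote without upper-casing.
    return f'"{part}"'

def all_qualified_names(parts: tuple[str, ...], name: str) -> list[str]:
    # Every suffix of the qualifier chain becomes a usable prefix.
    prefixes = [".".join(quote(x) for x in parts[i:]) for i in range(len(parts))]
    return [f"{prefix}.{name}" for prefix in prefixes]

# ('db', 'table') + 'id' -> ['"db"."table".id', '"table".id']; the bare name
# 'id' itself is prepended separately by all_spark_names_including_qualified_names.
print(all_qualified_names(("db", "table"), "id"))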
@@ -133,13 +115,13 @@ class ColumnNameMap:
             [], bool
         ] = lambda: global_config.spark_sql_caseSensitive,
         column_metadata: dict | None = None,
-        column_qualifiers: list[
+        column_qualifiers: list[set[ColumnQualifier]] = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> None:
         """
         spark_column_names: Original spark column names
         snowpark_column_names: Snowpark column names
-        column_metadata: This field is used to store metadata related to columns. Since Snowpark
+        column_metadata: This field is used to store metadata related to columns. Since Snowpark's Struct type does not support metadata,
             we use this attribute to store any metadata related to the columns.
             The key is the original Spark column name, and the value is the metadata.
             example: Dict('age', {'foo': 'bar'})
@@ -186,21 +168,18 @@
             c = ColumnNames(
                 spark_name=spark_name,
                 snowpark_name=snowpark_column_names[i],
-                qualifiers=column_qualifiers[i]
+                qualifiers=column_qualifiers[i]
+                if column_qualifiers and column_qualifiers[i]
+                else {ColumnQualifier.no_qualifier()},
                 catalog_info=catalog_info,
                 database_info=database_info,
             )
             self.columns.append(c)
 
-
-            spark_names_including_qualifier = get_list_of_spark_names_for_column(c)
-
-            for spark_name_including_qualifier in spark_names_including_qualifier:
+            for spark_name in c.all_spark_names_including_qualified_names():
                 # the same spark name can map to multiple snowpark names
-                self.spark_to_col[
-                self.uppercase_spark_to_col[
-                    spark_name_including_qualifier.upper()
-                ].append(c)
+                self.spark_to_col[spark_name].append(c)
+                self.uppercase_spark_to_col[spark_name.upper()].append(c)
 
             # the same snowpark name can map to multiple spark column
             # e.g. df.select(date_format('dt', 'yyy'), date_format('dt', 'yyyy')) ->
@@ -405,7 +384,7 @@
         if spark_column_name not in self.spark_to_col:
             return False
 
-        columns = self.spark_to_col[spark_column_name]
+        columns: list[ColumnNames] = self.spark_to_col[spark_column_name]
 
         # If we don't have multiple columns, there's no ambiguity to resolve
         if len(columns) <= 1:
@@ -416,30 +395,7 @@
         first_column = columns[0]
 
         for column in columns[1:]:
-
-            # If one has the attribute but the other doesn't, they're different
-            if hasattr(first_column, "snowpark_type") != hasattr(
-                column, "snowpark_type"
-            ):
-                return False
-            # If both have the attribute and values differ, they're different expressions
-            if (
-                hasattr(first_column, "snowpark_type")
-                and hasattr(column, "snowpark_type")
-                and first_column.snowpark_type != column.snowpark_type
-            ):
-                return False
-
-            # Check qualifiers attribute
-            # If one has the attribute but the other doesn't, they're different
-            if hasattr(first_column, "qualifiers") != hasattr(column, "qualifiers"):
-                return False
-            # If both have the attribute and values differ, they might be from different contexts
-            if (
-                hasattr(first_column, "qualifiers")
-                and hasattr(column, "qualifiers")
-                and first_column.qualifiers != column.qualifiers
-            ):
+            if first_column.qualifiers != column.qualifiers:
                 return False
 
         # Additional safety check: ensure all snowpark names are actually in our mapping
@@ -500,32 +456,30 @@
         return [c.spark_name for c in self.columns]
 
     def get_spark_and_snowpark_columns_with_qualifier_for_qualifier(
-        self,
-    ) -> tuple[list[str], list[str], list[
+        self, target_qualifier: ColumnQualifier
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         """
-        Returns the Spark and Snowpark column names along with their qualifiers for the specified
-        If a column does not have a qualifier, it will be None.
+        Returns the Spark and Snowpark column names along with their qualifiers for the specified qualifier.
         """
-        spark_columns = []
-        snowpark_columns = []
-        qualifiers = []
+        spark_columns: list[str] = []
+        snowpark_columns: list[str] = []
+        qualifiers: list[set[ColumnQualifier]] = []
 
+        normalized_qualifier = target_qualifier
         if not self.is_case_sensitive():
-
+            normalized_qualifier = target_qualifier.to_upper()
 
-        for
-
-
+        for column in self.columns:
+            # Normalize all qualifiers for comparison
+            column_qualifiers: set[ColumnQualifier] = (
+                {q.to_upper() for q in iter(column.qualifiers)}
                 if not self.is_case_sensitive()
-                else
+                else column.qualifiers
             )
-            if
-
-
-
-            spark_columns.append(c.spark_name)
-            snowpark_columns.append(c.snowpark_name)
-            qualifiers.append(c.qualifiers)
+            if any([q.matches(normalized_qualifier) for q in column_qualifiers]):
+                spark_columns.append(column.spark_name)
+                snowpark_columns.append(column.snowpark_name)
+                qualifiers.append(column.qualifiers)
 
         return spark_columns, snowpark_columns, qualifiers
 
@@ -539,19 +493,17 @@
             if self._quote_if_unquoted(c) not in cols_to_drop
         ]
 
-    def get_qualifiers(self) -> list[
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         """
         Returns the qualifiers for the columns.
-        If a column does not have a qualifier, it will be None.
         """
         return [c.qualifiers for c in self.columns]
 
     def get_qualifiers_for_columns_after_drop(
         self, cols_to_drop: list[str]
-    ) -> list[
+    ) -> list[set[ColumnQualifier]]:
        """
         Returns the qualifiers for the columns after dropping the specified columns.
-        If a column is dropped, its qualifier will be None.
         """
         return [
             c.qualifiers
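The next hunk splits qualifier lookup into a plural accessor and a backward-compatible singular one. A standalone sketch of the contract between them (names simplified; qualifiers modeled as plain tuples, an assumption for illustration):

def get_qualifier(qualifiers: set[tuple[str, ...]]) -> tuple[str, ...]:
    # Mirrors get_qualifier_for_spark_column: tolerate a single qualifier,
    # but treat multiple qualifiers for one name as an internal error.
    if len(qualifiers) > 1:
        raise ValueError("Multiple qualifiers found; expected only one.")
    return next(iter(qualifiers))

print(get_qualifier({("db", "table")}))  # ('db', 'table')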
@@ -562,10 +514,25 @@
     def get_qualifier_for_spark_column(
         self,
         spark_column_name: str,
-    ) ->
+    ) -> ColumnQualifier:
+        """
+        Backward compatibility: returns the first qualifier for the given Spark column name.
+        Throws if more than one qualifier exists.
+        """
+        qualifiers = self.get_qualifiers_for_spark_column(spark_column_name)
+        if len(qualifiers) > 1:
+            raise ValueError(
+                "Shouldn't happen. Multiple qualifiers found; expected only one."
+            )
+        return next(iter(qualifiers))
+
+    def get_qualifiers_for_spark_column(
+        self,
+        spark_column_name: str,
+    ) -> set[ColumnQualifier]:
         """
         Returns the qualifier for the specified Spark column name.
-        If the column does not exist, returns
+        If the column does not exist, returns empty ColumnQualifier.
         """
         if not self.is_case_sensitive():
             name = spark_column_name.upper()
@@ -577,7 +544,7 @@
         col = mapping.get(name)
 
         if col is None or len(col) == 0:
-            return
+            return {ColumnQualifier.no_qualifier()}
 
         return col[0].qualifiers
 
@@ -609,7 +576,7 @@
 
     def with_columns(
         self, new_spark_columns: list[str], new_snowpark_columns: list[str]
-    ) -> tuple[list[str], list[str], list[
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         """
         Returns an ordered list of spark and snowpark column names after adding the new columns through a withColumns call.
         All replaced columns retain their ordering in the dataframe. The new columns are added to the end of the list.
@@ -638,7 +605,7 @@
                 removed_index.add(index)
                 spark_columns.append(new_spark_columns[index])
                 snowpark_columns.append(new_snowpark_columns[index])
-                qualifiers.append(
+                qualifiers.append({ColumnQualifier.no_qualifier()})
             else:
                 spark_columns.append(c.spark_name)
                 snowpark_columns.append(c.snowpark_name)
@@ -648,7 +615,7 @@
             if i not in removed_index:
                 spark_columns.append(new_spark_columns[i])
                 snowpark_columns.append(new_snowpark_columns[i])
-                qualifiers.append(
+                qualifiers.append({ColumnQualifier.no_qualifier()})
 
         return spark_columns, snowpark_columns, qualifiers
 
@@ -745,7 +712,9 @@ class JoinColumnNameMap(ColumnNameMap):
         raise exception
 
     def get_spark_column_name_from_snowpark_column_name(
-        self,
+        self,
+        snowpark_column_name: str,
+        allow_non_exists: bool = False,
     ) -> str:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
@@ -788,32 +757,36 @@
 
     def with_columns(
         self, new_spark_columns: list[str], new_snowpark_columns: list[str]
-    ) -> tuple[list[str], list[str], list[
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    def get_qualifiers(self) -> list[
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
     def get_qualifiers_for_columns_after_drop(
         self, cols_to_drop: list[str]
-    ) -> list[
+    ) -> list[set[ColumnQualifier]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
     def get_spark_and_snowpark_columns_with_qualifier_for_qualifier(
-        self,
-    ) -> tuple[list[str], list[str], list[
+        self, target_qualifier: list[str]
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    def
+    def get_qualifiers_for_spark_column(
+        self, spark_column_name: str
+    ) -> set[ColumnQualifier]:
+        return {self.get_qualifier_for_spark_column(spark_column_name)}
 
+    def get_qualifier_for_spark_column(self, spark_column_name: str) -> ColumnQualifier:
         qualifier_left = self.left_column_mapping.get_qualifier_for_spark_column(
             spark_column_name
         )
@@ -821,9 +794,9 @@ class JoinColumnNameMap(ColumnNameMap):
             spark_column_name
         )
 
-        if (
+        if (not qualifier_left.is_empty) and (not qualifier_right.is_empty):
             exception = AnalysisException(f"Ambiguous column name {spark_column_name}")
             attach_custom_error_code(exception, ErrorCodes.AMBIGUOUS_COLUMN_NAME)
             raise exception
 
-        return qualifier_right if
+        return qualifier_right if qualifier_left.is_empty else qualifier_left
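The join-side resolution above reads as a three-case rule: both sides qualified is ambiguous, otherwise the side that actually carries a qualifier wins. A tiny sketch, with booleans standing in for ColumnQualifier.is_empty (assumed simplification):

def resolve_join_qualifier(left_is_empty: bool, right_is_empty: bool) -> str:
    # Both sides carry a real qualifier -> the name is ambiguous across the join.
    if (not left_is_empty) and (not right_is_empty):
        raise ValueError("Ambiguous column name")
    # Otherwise prefer whichever side is qualified (right when left is empty).
    return "right" if left_is_empty else "left"

print(resolve_join_qualifier(True, False))  # right
print(resolve_join_qualifier(False, True))  # left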
@@ -0,0 +1,47 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    quote_name_without_upper_casing,
+)
+
+
+@dataclass(frozen=True)
+class ColumnQualifier:
+    parts: tuple[str, ...]
+
+    def __post_init__(self) -> None:
+        if not all(isinstance(x, str) for x in self.parts):
+            raise TypeError("ColumnQualifier.parts must be strings")
+
+    @property
+    def is_empty(self) -> bool:
+        return len(self.parts) == 0
+
+    @classmethod
+    def no_qualifier(cls) -> ColumnQualifier:
+        return cls(())
+
+    def all_qualified_names(self, name: str) -> list[str]:
+        qualifier_parts = self.parts
+        qualifier_prefixes = [
+            ".".join(quote_name_without_upper_casing(x) for x in qualifier_parts[i:])
+            for i in range(len(qualifier_parts))
+        ]
+        return [f"{prefix}.{name}" for prefix in qualifier_prefixes]
+
+    def to_upper(self):
+        return ColumnQualifier(tuple(part.upper() for part in self.parts))
+
+    def matches(self, target: ColumnQualifier) -> bool:
+        if self.is_empty or target.is_empty:
+            return False
+        # If the column has fewer qualifiers than the target, it cannot match
+        if len(self.parts) < len(target.parts):
+            return False
+        return self.parts[-len(target.parts) :] == target.parts
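The suffix rule in matches is the core of qualifier resolution: a column qualified as ('db', 'table') should be selectable via table.* as well as db.table.*. If the class behaves as the new file defines it, a quick check would look like this (illustrative; assumes the 0.32.0 package is installed):

from snowflake.snowpark_connect.column_qualifier import ColumnQualifier

q = ColumnQualifier(("db", "table"))
print(q.matches(ColumnQualifier(("table",))))       # True: target is a suffix
print(q.matches(ColumnQualifier(("db", "table"))))  # True: exact match
print(q.matches(ColumnQualifier(("db",))))          # False: 'db' alone is not a suffix
print(q.matches(ColumnQualifier(())))               # False: empty never matches
print(ColumnQualifier(("t1",)).to_upper().parts)    # ('T1',) for case-insensitive mode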
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable
 
 from snowflake import snowpark
 from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 
 if TYPE_CHECKING:
     from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -58,7 +59,7 @@ class DataFrameContainer:
         snowpark_column_names: list[str],
         snowpark_column_types: list | None = None,
         column_metadata: dict | None = None,
-        column_qualifiers: list[
+        column_qualifiers: list[set[ColumnQualifier]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
         table_name: str | None = None,
         alias: str | None = None,
@@ -220,7 +221,7 @@
         spark_column_names: list[str],
         snowpark_column_names: list[str],
         column_metadata: dict | None = None,
-        column_qualifiers: list[
+        column_qualifiers: list[set[ColumnQualifier]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> ColumnNameMap:
         """Create a ColumnNameMap with the provided configuration."""
@@ -54,9 +54,11 @@ def _create_column_rename_map(
             new_column_name = (
                 f"{new_column_name}_DEDUP_{column_counts[normalized_name] - 1}"
             )
-            renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name,
+            renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, set()))
         else:
-            not_renamed_cols.append(
+            not_renamed_cols.append(
+                ColumnNames(new_column_name, col.snowpark_name, set())
+            )
 
     if len(renamed_cols) == 0:
         return {
@@ -230,18 +230,18 @@ def map_expression(
                     | exp.sort_order.SORT_DIRECTION_ASCENDING
                 ):
                     if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_LAST:
-
+                        col = snowpark_fn.asc_nulls_last(child_column.col)
                     else:
                         # If nulls are not specified or null_ordering is FIRST in the sort order, Spark defaults to nulls
                         # first in the case of ascending sort order.
-
+                        col = snowpark_fn.asc_nulls_first(child_column.col)
                 case exp.sort_order.SORT_DIRECTION_DESCENDING:
                     if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_FIRST:
-
+                        col = snowpark_fn.desc_nulls_first(child_column.col)
                     else:
                         # If nulls are not specified or null_ordering is LAST in the sort order, Spark defaults to nulls
                         # last in the case of descending sort order.
-
+                        col = snowpark_fn.desc_nulls_last(child_column.col)
                 case _:
                     exception = ValueError(
                         f"Invalid sort direction {exp.sort_order.direction}"
@@ -250,6 +250,7 @@ def map_expression(
                         exception, ErrorCodes.INVALID_FUNCTION_ARGUMENT
                     )
                     raise exception
+            return [child_name], TypedColumn(col, lambda: typer.type(col))
         case "unresolved_attribute":
            col_name, col = map_att.map_unresolved_attribute(exp, column_mapping, typer)
             # Check if this is a multi-column regex expansion
@@ -277,6 +277,9 @@ def _format_day_time_interval(
     if is_negative:
         days = -days
 
+    # Calculate days string representation (handle -0 case)
+    days_str = "-0" if (is_negative and days == 0) else str(days)
+
     # Format based on the specific start/end field context
     if (
         start_field == DayTimeIntervalType.DAY and end_field == DayTimeIntervalType.DAY
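The new days_str exists for one edge case: a negative interval of less than one day must still print a signed day field, but str(0) would drop the sign. A minimal sketch of the computation from the hunk above:

def render_day_field(days: int, is_negative: bool) -> str:
    if is_negative:
        days = -days
    # str() of zero would lose the sign, so "-0" is special-cased.
    return "-0" if (is_negative and days == 0) else str(days)

print(render_day_field(0, True))   # '-0'  e.g. INTERVAL '-0 03:00:00' DAY TO SECOND
print(render_day_field(3, True))   # '-3'
print(render_day_field(2, False))  # '2'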
@@ -344,7 +347,10 @@
         start_field == DayTimeIntervalType.HOUR
         and end_field == DayTimeIntervalType.MINUTE
     ):  # HOUR TO MINUTE
-
+        if is_negative:
+            str_value = f"INTERVAL '-{_TWO_DIGIT_FORMAT.format(hours)}:{_TWO_DIGIT_FORMAT.format(minutes)}' HOUR TO MINUTE"
+        else:
+            str_value = f"INTERVAL '{_TWO_DIGIT_FORMAT.format(hours)}:{_TWO_DIGIT_FORMAT.format(minutes)}' HOUR TO MINUTE"
     elif (
         start_field == DayTimeIntervalType.HOUR
         and end_field == DayTimeIntervalType.SECOND
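A sketch of what the new HOUR TO MINUTE branch renders, assuming _TWO_DIGIT_FORMAT is a zero-padding format string like "{:02d}" (an assumption; the diff does not show its definition):

TWO_DIGIT = "{:02d}"  # assumed stand-in for _TWO_DIGIT_FORMAT

def hour_to_minute(hours: int, minutes: int, is_negative: bool) -> str:
    sign = "-" if is_negative else ""
    return f"INTERVAL '{sign}{TWO_DIGIT.format(hours)}:{TWO_DIGIT.format(minutes)}' HOUR TO MINUTE"

print(hour_to_minute(5, 30, True))   # INTERVAL '-05:30' HOUR TO MINUTE
print(hour_to_minute(5, 30, False))  # INTERVAL '05:30' HOUR TO MINUTE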
@@ -368,21 +374,21 @@
         and end_field == DayTimeIntervalType.SECOND
     ):  # DAY TO SECOND
         if seconds == int(seconds):
-            str_value = f"INTERVAL '{
+            str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
         else:
             seconds_str = _format_seconds_precise(seconds)
-            str_value = f"INTERVAL '{
+            str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
     else:
         # Fallback - use smart formatting like the original literal.py logic
-        if days
+        if days >= 0:
             if hours == 0 and minutes == 0 and seconds == 0:
                 str_value = f"INTERVAL '{int(days)}' DAY"
             else:
                 if seconds == int(seconds):
-                    str_value = f"INTERVAL '{
+                    str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
                 else:
                     seconds_str = _format_seconds_precise(seconds)
-                    str_value = f"INTERVAL '{
+                    str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
         elif hours > 0:
             if minutes == 0 and seconds == 0:
                 str_value = f"INTERVAL '{_format_time_component(hours)}' HOUR"
@@ -260,12 +260,47 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Expression
     class_name = str(exp.getClass().getSimpleName())
     match class_name:
         case "AggregateExpression":
-
+            aggregate_func = as_java_list(exp.children())[0]
+            func_name = aggregate_func.nodeName()
             args = [
                 map_logical_plan_expression(e)
-                for e in list(as_java_list(
+                for e in list(as_java_list(aggregate_func.children()))
             ]
-
+
+            # Special handling for percentile_cont and percentile_disc
+            # These functions have a 'reverse' property that indicates sort order
+            # Pass it as a 3rd argument (sort_order expression) without modifying children
+            if func_name.lower() in ("percentile_cont", "percentiledisc"):
+                # percentile_cont/disc should always have exactly 2 children: unresolved attribute and percentile value
+                if len(args) != 2:
+                    exception = AssertionError(
+                        f"{func_name} expected 2 args but got {len(args)}"
+                    )
+                    attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                    raise exception
+
+                reverse = bool(aggregate_func.reverse())
+
+                direction = (
+                    expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING
+                    if reverse
+                    else expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING
+                )
+
+                sort_order_expr = expressions_proto.Expression(
+                    sort_order=expressions_proto.Expression.SortOrder(
+                        child=args[0],
+                        direction=direction,
+                    )
+                )
+                args.append(sort_order_expr)
+                proto = apply_filter_clause(func_name, [args[0]], exp)
+                # second arg is a literal value and it doesn't make sense to apply filter on it.
+                # also skips filtering on sort_order.
+                proto.unresolved_function.arguments.append(args[1])
+                proto.unresolved_function.arguments.append(sort_order_expr)
+            else:
+                proto = apply_filter_clause(func_name, args, exp)
         case "Alias":
             proto = expressions_proto.Expression(
                 alias=expressions_proto.Expression.Alias(
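The reverse flag above comes from Catalyst's percentile aggregates: WITHIN GROUP (ORDER BY ... DESC) sets it, and the mapper turns it into a SortOrder direction passed as an extra argument. An illustrative sketch of the direction mapping (the SQL in the comment is a hypothetical query that would exercise this path):

# e.g. SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY price DESC) FROM items
def sort_direction(reverse: bool) -> str:
    # reverse=True on the aggregate node means a descending sort order
    return "SORT_DIRECTION_DESCENDING" if reverse else "SORT_DIRECTION_ASCENDING"

print(sort_direction(True))   # SORT_DIRECTION_DESCENDING
print(sort_direction(False))  # SORT_DIRECTION_ASCENDING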
@@ -247,7 +247,7 @@ def map_unresolved_attribute(
                 )
             )
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_col_name)
             typed_col = TypedColumn(col, lambda: typer.type(col))
             typed_col.set_qualifiers(qualifiers)
             # Store matched columns info for later use
@@ -262,7 +262,7 @@
                 )
             )
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_col_name)
             typed_col = TypedColumn(col, lambda: typer.type(col))
             typed_col.set_qualifiers(qualifiers)
             return (matched_columns[0], typed_col)
@@ -280,7 +280,7 @@
         )
         if snowpark_name is not None:
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_attr_name)
         else:
             # this means it has to be a struct column with a field name
             snowpark_name: str | None = None
@@ -338,7 +338,7 @@
         )
         if snowpark_name is not None:
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(
                 unqualified_name
             )
             typed_col = TypedColumn(col, lambda: typer.type(col))
@@ -405,7 +405,7 @@
         for field_name in path:
             col = col.getItem(field_name)
 
-        qualifiers =
+        qualifiers = set()
 
         typed_col = TypedColumn(col, lambda: typer.type(col))
         typed_col.set_qualifiers(qualifiers)