snowpark-connect 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of snowpark-connect might be problematic.

Files changed (38)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +73 -100
  3. snowflake/snowpark_connect/column_qualifier.py +47 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  7. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  8. snowflake/snowpark_connect/expression/map_sql_expression.py +38 -3
  9. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +5 -5
  10. snowflake/snowpark_connect/expression/map_unresolved_function.py +869 -107
  11. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  12. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  13. snowflake/snowpark_connect/relation/map_aggregate.py +8 -5
  14. snowflake/snowpark_connect/relation/map_column_ops.py +4 -3
  15. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  16. snowflake/snowpark_connect/relation/map_join.py +5 -2
  17. snowflake/snowpark_connect/relation/map_sql.py +33 -1
  18. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  20. snowflake/snowpark_connect/relation/write/map_write.py +29 -14
  21. snowflake/snowpark_connect/server.py +1 -2
  22. snowflake/snowpark_connect/type_mapping.py +36 -3
  23. snowflake/snowpark_connect/typed_column.py +8 -6
  24. snowflake/snowpark_connect/utils/session.py +19 -3
  25. snowflake/snowpark_connect/version.py +1 -1
  26. snowflake/snowpark_decoder/dp_session.py +1 -1
  27. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/METADATA +5 -2
  28. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/RECORD +36 -37
  29. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  30. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  31. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-connect +0 -0
  32. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-session +0 -0
  33. {snowpark_connect-0.31.0.data → snowpark_connect-0.32.0.data}/scripts/snowpark-submit +0 -0
  34. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/WHEEL +0 -0
  35. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE-binary +0 -0
  36. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/licenses/NOTICE-binary +0 -0
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.32.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/__init__.py

@@ -10,6 +10,7 @@ sys.path.append(str(pathlib.Path(__file__).parent / "includes/python"))
 
 from .server import get_session  # noqa: E402, F401
 from .server import start_session  # noqa: E402, F401
+from .utils.session import skip_session_configuration  # noqa: E402, F401
 
 # Turn off catalog warning for Snowpark
 sp_logger = logging.getLogger("snowflake.snowpark")

snowflake/snowpark_connect/column_name_handler.py

@@ -13,12 +13,10 @@ from functools import cached_property
 from pyspark.errors.exceptions.base import AnalysisException
 
 from snowflake.snowpark import DataFrame
-from snowflake.snowpark._internal.analyzer.analyzer_utils import (
-    quote_name_without_upper_casing,
-    unquote_if_quoted,
-)
+from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark._internal.utils import quote_name
 from snowflake.snowpark.types import StructType
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code

@@ -97,31 +95,15 @@ def make_column_names_snowpark_compatible(
 class ColumnNames:
     spark_name: str
     snowpark_name: str
-    qualifiers: list[str]
+    qualifiers: set[ColumnQualifier]
     catalog_info: str | None = None  # Catalog from fully qualified name
     database_info: str | None = None  # Database from fully qualified name
 
-
-def get_list_of_spark_names_for_column(column_names: ColumnNames) -> list[str]:
-    """
-    Returns a list of Spark names for a given ColumnNames object.
-    This is useful when a single Spark name maps to multiple names due to table alias.
-
-    For example, if the column name is 'id' and the qualifiers are ['db', 'table'],
-    then the possible Spark names are:
-    ['id', 'db.table.id', 'table.id']
-    """
-    spark_name = column_names.spark_name
-    qualifiers = column_names.qualifiers
-
-    qualifier_suffixes_list = [
-        ".".join(quote_name_without_upper_casing(x) for x in qualifiers[i:])
-        for i in range(len(qualifiers))
-    ]
-    return [spark_name] + [
-        f"{qualifier_suffix}.{spark_name}"
-        for qualifier_suffix in qualifier_suffixes_list
-    ]
+    def all_spark_names_including_qualified_names(self):
+        all_names = [self.spark_name]
+        for qualifier in self.qualifiers:
+            all_names.extend(qualifier.all_qualified_names(self.spark_name))
+        return all_names
 
 
 class ColumnNameMap:
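
As an aside for orientation (not part of the diff): the expansion that all_spark_names_including_qualified_names now delegates to ColumnQualifier mirrors the example from the removed get_list_of_spark_names_for_column docstring. A minimal sketch, assuming quote_name_without_upper_casing wraps each qualifier part in double quotes:

from snowflake.snowpark_connect.column_qualifier import ColumnQualifier

# A column "id" carried under the qualifier ("db", "table") stays reachable
# under the bare name plus every qualifier suffix.
qualifier = ColumnQualifier(("db", "table"))
names = ["id"] + qualifier.all_qualified_names("id")
# Expected: ['id', '"db"."table".id', '"table".id']
print(names)
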
@@ -133,13 +115,13 @@ class ColumnNameMap:
             [], bool
         ] = lambda: global_config.spark_sql_caseSensitive,
         column_metadata: dict | None = None,
-        column_qualifiers: list[list[str]] | None = None,
+        column_qualifiers: list[set[ColumnQualifier]] = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> None:
         """
         spark_column_names: Original spark column names
         snowpark_column_names: Snowpark column names
-        column_metadata: This field is used to store metadata related to columns. Since Snowparks Struct type does not support metadata,
+        column_metadata: This field is used to store metadata related to columns. Since Snowpark's Struct type does not support metadata,
             we use this attribute to store any metadata related to the columns.
             The key is the original Spark column name, and the value is the metadata.
             example: Dict('age', {'foo': 'bar'})

@@ -186,21 +168,18 @@ class ColumnNameMap:
             c = ColumnNames(
                 spark_name=spark_name,
                 snowpark_name=snowpark_column_names[i],
-                qualifiers=column_qualifiers[i] if column_qualifiers else [],
+                qualifiers=column_qualifiers[i]
+                if column_qualifiers and column_qualifiers[i]
+                else {ColumnQualifier.no_qualifier()},
                 catalog_info=catalog_info,
                 database_info=database_info,
             )
             self.columns.append(c)
 
-            # we want to store all the spark names including qualifiers (these are generated from table alias or dataframe alias)
-            spark_names_including_qualifier = get_list_of_spark_names_for_column(c)
-
-            for spark_name_including_qualifier in spark_names_including_qualifier:
+            for spark_name in c.all_spark_names_including_qualified_names():
                 # the same spark name can map to multiple snowpark names
-                self.spark_to_col[spark_name_including_qualifier].append(c)
-                self.uppercase_spark_to_col[
-                    spark_name_including_qualifier.upper()
-                ].append(c)
+                self.spark_to_col[spark_name].append(c)
+                self.uppercase_spark_to_col[spark_name.upper()].append(c)
 
             # the same snowpark name can map to multiple spark column
             # e.g. df.select(date_format('dt', 'yyy'), date_format('dt', 'yyyy')) ->

@@ -405,7 +384,7 @@ class ColumnNameMap:
         if spark_column_name not in self.spark_to_col:
             return False
 
-        columns = self.spark_to_col[spark_column_name]
+        columns: list[ColumnNames] = self.spark_to_col[spark_column_name]
 
         # If we don't have multiple columns, there's no ambiguity to resolve
         if len(columns) <= 1:

@@ -416,30 +395,7 @@ class ColumnNameMap:
         first_column = columns[0]
 
         for column in columns[1:]:
-            # Check snowpark_type attribute
-            # If one has the attribute but the other doesn't, they're different
-            if hasattr(first_column, "snowpark_type") != hasattr(
-                column, "snowpark_type"
-            ):
-                return False
-            # If both have the attribute and values differ, they're different expressions
-            if (
-                hasattr(first_column, "snowpark_type")
-                and hasattr(column, "snowpark_type")
-                and first_column.snowpark_type != column.snowpark_type
-            ):
-                return False
-
-            # Check qualifiers attribute
-            # If one has the attribute but the other doesn't, they're different
-            if hasattr(first_column, "qualifiers") != hasattr(column, "qualifiers"):
-                return False
-            # If both have the attribute and values differ, they might be from different contexts
-            if (
-                hasattr(first_column, "qualifiers")
-                and hasattr(column, "qualifiers")
-                and first_column.qualifiers != column.qualifiers
-            ):
+            if first_column.qualifiers != column.qualifiers:
                 return False
 
         # Additional safety check: ensure all snowpark names are actually in our mapping

@@ -500,32 +456,30 @@ class ColumnNameMap:
         return [c.spark_name for c in self.columns]
 
     def get_spark_and_snowpark_columns_with_qualifier_for_qualifier(
-        self, qualifiers_input: list[str]
-    ) -> tuple[list[str], list[str], list[list[str]]]:
+        self, target_qualifier: ColumnQualifier
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         """
-        Returns the Spark and Snowpark column names along with their qualifiers for the specified qualifiers.
-        If a column does not have a qualifier, it will be None.
+        Returns the Spark and Snowpark column names along with their qualifiers for the specified qualifier.
         """
-        spark_columns = []
-        snowpark_columns = []
-        qualifiers = []
+        spark_columns: list[str] = []
+        snowpark_columns: list[str] = []
+        qualifiers: list[set[ColumnQualifier]] = []
 
+        normalized_qualifier = target_qualifier
         if not self.is_case_sensitive():
-            qualifiers_input = [q.upper() for q in qualifiers_input]
+            normalized_qualifier = target_qualifier.to_upper()
 
-        for c in self.columns:
-            col_qualifiers = (
-                [q.upper() for q in c.qualifiers]
+        for column in self.columns:
+            # Normalize all qualifiers for comparison
+            column_qualifiers: set[ColumnQualifier] = (
+                {q.to_upper() for q in iter(column.qualifiers)}
                 if not self.is_case_sensitive()
-                else c.qualifiers
+                else column.qualifiers
             )
-            if len(col_qualifiers) < len(qualifiers_input):
-                # If the column has fewer qualifiers than the input, it cannot match
-                continue
-            if col_qualifiers[-len(qualifiers_input) :] == qualifiers_input:
-                spark_columns.append(c.spark_name)
-                snowpark_columns.append(c.snowpark_name)
-                qualifiers.append(c.qualifiers)
+            if any([q.matches(normalized_qualifier) for q in column_qualifiers]):
+                spark_columns.append(column.spark_name)
+                snowpark_columns.append(column.snowpark_name)
+                qualifiers.append(column.qualifiers)
 
         return spark_columns, snowpark_columns, qualifiers
 

@@ -539,19 +493,17 @@ class ColumnNameMap:
             if self._quote_if_unquoted(c) not in cols_to_drop
         ]
 
-    def get_qualifiers(self) -> list[list[str]]:
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         """
         Returns the qualifiers for the columns.
-        If a column does not have a qualifier, it will be None.
         """
         return [c.qualifiers for c in self.columns]
 
     def get_qualifiers_for_columns_after_drop(
         self, cols_to_drop: list[str]
-    ) -> list[list[str]]:
+    ) -> list[set[ColumnQualifier]]:
         """
         Returns the qualifiers for the columns after dropping the specified columns.
-        If a column is dropped, its qualifier will be None.
         """
         return [
             c.qualifiers

@@ -562,10 +514,25 @@ class ColumnNameMap:
     def get_qualifier_for_spark_column(
         self,
         spark_column_name: str,
-    ) -> list[str]:
+    ) -> ColumnQualifier:
+        """
+        Backward compatibility: returns the first qualifier for the given Spark column name.
+        Throws if more than one qualifier exists.
+        """
+        qualifiers = self.get_qualifiers_for_spark_column(spark_column_name)
+        if len(qualifiers) > 1:
+            raise ValueError(
+                "Shouldn't happen. Multiple qualifiers found; expected only one."
+            )
+        return next(iter(qualifiers))
+
+    def get_qualifiers_for_spark_column(
+        self,
+        spark_column_name: str,
+    ) -> set[ColumnQualifier]:
         """
         Returns the qualifier for the specified Spark column name.
-        If the column does not exist, returns None.
+        If the column does not exist, returns empty ColumnQualifier.
         """
         if not self.is_case_sensitive():
             name = spark_column_name.upper()

@@ -577,7 +544,7 @@ class ColumnNameMap:
         col = mapping.get(name)
 
         if col is None or len(col) == 0:
-            return []
+            return {ColumnQualifier.no_qualifier()}
 
         return col[0].qualifiers
 

@@ -609,7 +576,7 @@ class ColumnNameMap:
 
     def with_columns(
         self, new_spark_columns: list[str], new_snowpark_columns: list[str]
-    ) -> tuple[list[str], list[str], list[list[str]]]:
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         """
         Returns an ordered list of spark and snowpark column names after adding the new columns through a withColumns call.
         All replaced columns retain their ordering in the dataframe. The new columns are added to the end of the list.

@@ -638,7 +605,7 @@ class ColumnNameMap:
                 removed_index.add(index)
                 spark_columns.append(new_spark_columns[index])
                 snowpark_columns.append(new_snowpark_columns[index])
-                qualifiers.append([])
+                qualifiers.append({ColumnQualifier.no_qualifier()})
             else:
                 spark_columns.append(c.spark_name)
                 snowpark_columns.append(c.snowpark_name)

@@ -648,7 +615,7 @@ class ColumnNameMap:
             if i not in removed_index:
                 spark_columns.append(new_spark_columns[i])
                 snowpark_columns.append(new_snowpark_columns[i])
-                qualifiers.append([])
+                qualifiers.append({ColumnQualifier.no_qualifier()})
 
         return spark_columns, snowpark_columns, qualifiers
 

@@ -745,7 +712,9 @@ class JoinColumnNameMap(ColumnNameMap):
         raise exception
 
     def get_spark_column_name_from_snowpark_column_name(
-        self, snowpark_column_name: str
+        self,
+        snowpark_column_name: str,
+        allow_non_exists: bool = False,
     ) -> str:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)

@@ -788,32 +757,36 @@ class JoinColumnNameMap(ColumnNameMap):
 
     def with_columns(
         self, new_spark_columns: list[str], new_snowpark_columns: list[str]
-    ) -> tuple[list[str], list[str], list[list[str]]]:
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    def get_qualifiers(self) -> list[list[str]]:
+    def get_qualifiers(self) -> list[set[ColumnQualifier]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
     def get_qualifiers_for_columns_after_drop(
         self, cols_to_drop: list[str]
-    ) -> list[list[str]]:
+    ) -> list[set[ColumnQualifier]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
     def get_spark_and_snowpark_columns_with_qualifier_for_qualifier(
-        self, qualifiers_input: list[str]
-    ) -> tuple[list[str], list[str], list[list[str]]]:
+        self, target_qualifier: list[str]
+    ) -> tuple[list[str], list[str], list[set[ColumnQualifier]]]:
         exception = NotImplementedError("Method not implemented!")
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    def get_qualifier_for_spark_column(self, spark_column_name: str) -> list[str]:
+    def get_qualifiers_for_spark_column(
+        self, spark_column_name: str
+    ) -> set[ColumnQualifier]:
+        return {self.get_qualifier_for_spark_column(spark_column_name)}
 
+    def get_qualifier_for_spark_column(self, spark_column_name: str) -> ColumnQualifier:
         qualifier_left = self.left_column_mapping.get_qualifier_for_spark_column(
             spark_column_name
         )

@@ -821,9 +794,9 @@ class JoinColumnNameMap(ColumnNameMap):
             spark_column_name
         )
 
-        if (len(qualifier_left) > 0) and (len(qualifier_right) > 0):
+        if (not qualifier_left.is_empty) and (not qualifier_right.is_empty):
             exception = AnalysisException(f"Ambiguous column name {spark_column_name}")
             attach_custom_error_code(exception, ErrorCodes.AMBIGUOUS_COLUMN_NAME)
             raise exception
 
-        return qualifier_right if len(qualifier_left) == 0 else qualifier_left
+        return qualifier_right if qualifier_left.is_empty else qualifier_left
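
For context (not from the diff itself): this check corresponds to Spark's usual behaviour when a bare column name resolves on both sides of a join. A hedged, hypothetical repro sketch, assuming a plain PySpark session; with snowpark-connect the session would instead come from the package's own start_session/get_session helpers:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

left = spark.createDataFrame([(1, "a")], ["id", "v"]).alias("l")
right = spark.createDataFrame([(1, "b")], ["id", "w"]).alias("r")
joined = left.join(right, left["id"] == right["id"])

# joined.select("id")   # both sides qualify "id" -> AnalysisException (ambiguous)
joined.select("l.id", "r.w").show()  # qualifying through an alias disambiguates
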
snowflake/snowpark_connect/column_qualifier.py (new file)

@@ -0,0 +1,47 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    quote_name_without_upper_casing,
+)
+
+
+@dataclass(frozen=True)
+class ColumnQualifier:
+    parts: tuple[str, ...]
+
+    def __post_init__(self) -> None:
+        if not all(isinstance(x, str) for x in self.parts):
+            raise TypeError("ColumnQualifier.parts must be strings")
+
+    @property
+    def is_empty(self) -> bool:
+        return len(self.parts) == 0
+
+    @classmethod
+    def no_qualifier(cls) -> ColumnQualifier:
+        return cls(())
+
+    def all_qualified_names(self, name: str) -> list[str]:
+        qualifier_parts = self.parts
+        qualifier_prefixes = [
+            ".".join(quote_name_without_upper_casing(x) for x in qualifier_parts[i:])
+            for i in range(len(qualifier_parts))
+        ]
+        return [f"{prefix}.{name}" for prefix in qualifier_prefixes]
+
+    def to_upper(self):
+        return ColumnQualifier(tuple(part.upper() for part in self.parts))
+
+    def matches(self, target: ColumnQualifier) -> bool:
+        if self.is_empty or target.is_empty:
+            return False
+        # If the column has fewer qualifiers than the target, it cannot match
+        if len(self.parts) < len(target.parts):
+            return False
+        return self.parts[-len(target.parts) :] == target.parts
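
A brief usage sketch of the new class (illustrative only; the identifier values are made up, behaviour inferred from the code above):

from snowflake.snowpark_connect.column_qualifier import ColumnQualifier

q = ColumnQualifier(("db", "schema", "table"))

# matches() is a suffix match on the qualifier parts.
assert q.matches(ColumnQualifier(("table",)))
assert q.matches(ColumnQualifier(("schema", "table")))
assert not q.matches(ColumnQualifier(("other_table",)))

# Empty qualifiers never match, and case normalization is explicit.
assert not ColumnQualifier.no_qualifier().matches(q)
assert q.to_upper().parts == ("DB", "SCHEMA", "TABLE")
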
snowflake/snowpark_connect/dataframe_container.py

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable
 
 from snowflake import snowpark
 from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 
 if TYPE_CHECKING:
     from snowflake.snowpark_connect.column_name_handler import ColumnNameMap

@@ -58,7 +59,7 @@ class DataFrameContainer:
         snowpark_column_names: list[str],
         snowpark_column_types: list | None = None,
         column_metadata: dict | None = None,
-        column_qualifiers: list[list[str]] | None = None,
+        column_qualifiers: list[set[ColumnQualifier]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
         table_name: str | None = None,
         alias: str | None = None,

@@ -220,7 +221,7 @@ class DataFrameContainer:
         spark_column_names: list[str],
         snowpark_column_names: list[str],
         column_metadata: dict | None = None,
-        column_qualifiers: list[list[str]] | None = None,
+        column_qualifiers: list[set[ColumnQualifier]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> ColumnNameMap:
         """Create a ColumnNameMap with the provided configuration."""

snowflake/snowpark_connect/execute_plan/map_execution_command.py

@@ -54,9 +54,11 @@ def _create_column_rename_map(
             new_column_name = (
                 f"{new_column_name}_DEDUP_{column_counts[normalized_name] - 1}"
             )
-            renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, []))
+            renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, set()))
         else:
-            not_renamed_cols.append(ColumnNames(new_column_name, col.snowpark_name, []))
+            not_renamed_cols.append(
+                ColumnNames(new_column_name, col.snowpark_name, set())
+            )
 
     if len(renamed_cols) == 0:
         return {
snowflake/snowpark_connect/expression/map_expression.py

@@ -230,18 +230,18 @@ def map_expression(
                 | exp.sort_order.SORT_DIRECTION_ASCENDING
             ):
                 if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_LAST:
-                    return [child_name], snowpark_fn.asc_nulls_last(child_column)
+                    col = snowpark_fn.asc_nulls_last(child_column.col)
                 else:
                     # If nulls are not specified or null_ordering is FIRST in the sort order, Spark defaults to nulls
                     # first in the case of ascending sort order.
-                    return [child_name], snowpark_fn.asc_nulls_first(child_column)
+                    col = snowpark_fn.asc_nulls_first(child_column.col)
             case exp.sort_order.SORT_DIRECTION_DESCENDING:
                 if exp.sort_order.null_ordering == exp.sort_order.SORT_NULLS_FIRST:
-                    return [child_name], snowpark_fn.desc_nulls_first(child_column)
+                    col = snowpark_fn.desc_nulls_first(child_column.col)
                 else:
                     # If nulls are not specified or null_ordering is LAST in the sort order, Spark defaults to nulls
                     # last in the case of descending sort order.
-                    return [child_name], snowpark_fn.desc_nulls_last(child_column)
+                    col = snowpark_fn.desc_nulls_last(child_column.col)
             case _:
                 exception = ValueError(
                     f"Invalid sort direction {exp.sort_order.direction}"

@@ -250,6 +250,7 @@ def map_expression(
                     exception, ErrorCodes.INVALID_FUNCTION_ARGUMENT
                 )
                 raise exception
+        return [child_name], TypedColumn(col, lambda: typer.type(col))
         case "unresolved_attribute":
             col_name, col = map_att.map_unresolved_attribute(exp, column_mapping, typer)
             # Check if this is a multi-column regex expansion
snowflake/snowpark_connect/expression/map_extension.py

@@ -277,6 +277,9 @@ def _format_day_time_interval(
     if is_negative:
         days = -days
 
+    # Calculate days string representation (handle -0 case)
+    days_str = "-0" if (is_negative and days == 0) else str(days)
+
     # Format based on the specific start/end field context
     if (
         start_field == DayTimeIntervalType.DAY and end_field == DayTimeIntervalType.DAY

@@ -344,7 +347,10 @@ def _format_day_time_interval(
         start_field == DayTimeIntervalType.HOUR
         and end_field == DayTimeIntervalType.MINUTE
     ):  # HOUR TO MINUTE
-        str_value = f"INTERVAL '{_TWO_DIGIT_FORMAT.format(hours)}:{_TWO_DIGIT_FORMAT.format(minutes)}' HOUR TO MINUTE"
+        if is_negative:
+            str_value = f"INTERVAL '-{_TWO_DIGIT_FORMAT.format(hours)}:{_TWO_DIGIT_FORMAT.format(minutes)}' HOUR TO MINUTE"
+        else:
+            str_value = f"INTERVAL '{_TWO_DIGIT_FORMAT.format(hours)}:{_TWO_DIGIT_FORMAT.format(minutes)}' HOUR TO MINUTE"
     elif (
         start_field == DayTimeIntervalType.HOUR
         and end_field == DayTimeIntervalType.SECOND

@@ -368,21 +374,21 @@ def _format_day_time_interval(
         and end_field == DayTimeIntervalType.SECOND
     ):  # DAY TO SECOND
         if seconds == int(seconds):
-            str_value = f"INTERVAL '{days} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
+            str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
         else:
             seconds_str = _format_seconds_precise(seconds)
-            str_value = f"INTERVAL '{days} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
+            str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
     else:
         # Fallback - use smart formatting like the original literal.py logic
-        if days > 0:
+        if days >= 0:
            if hours == 0 and minutes == 0 and seconds == 0:
                str_value = f"INTERVAL '{int(days)}' DAY"
            else:
                if seconds == int(seconds):
-                    str_value = f"INTERVAL '{days} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
+                    str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{_format_time_component(int(seconds))}' DAY TO SECOND"
                else:
                    seconds_str = _format_seconds_precise(seconds)
-                    str_value = f"INTERVAL '{days} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
+                    str_value = f"INTERVAL '{days_str} {_format_time_component(hours)}:{_format_time_component(minutes)}:{seconds_str}' DAY TO SECOND"
         elif hours > 0:
             if minutes == 0 and seconds == 0:
                 str_value = f"INTERVAL '{_format_time_component(hours)}' HOUR"
snowflake/snowpark_connect/expression/map_sql_expression.py

@@ -260,12 +260,47 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Expression
     class_name = str(exp.getClass().getSimpleName())
     match class_name:
         case "AggregateExpression":
-            func_name = as_java_list(exp.children())[0].nodeName()
+            aggregate_func = as_java_list(exp.children())[0]
+            func_name = aggregate_func.nodeName()
             args = [
                 map_logical_plan_expression(e)
-                for e in list(as_java_list(as_java_list(exp.children())[0].children()))
+                for e in list(as_java_list(aggregate_func.children()))
             ]
-            proto = apply_filter_clause(func_name, args, exp)
+
+            # Special handling for percentile_cont and percentile_disc
+            # These functions have a 'reverse' property that indicates sort order
+            # Pass it as a 3rd argument (sort_order expression) without modifying children
+            if func_name.lower() in ("percentile_cont", "percentiledisc"):
+                # percentile_cont/disc should always have exactly 2 children: unresolved attribute and percentile value
+                if len(args) != 2:
+                    exception = AssertionError(
+                        f"{func_name} expected 2 args but got {len(args)}"
+                    )
+                    attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
+                    raise exception
+
+                reverse = bool(aggregate_func.reverse())
+
+                direction = (
+                    expressions_proto.Expression.SortOrder.SORT_DIRECTION_DESCENDING
+                    if reverse
+                    else expressions_proto.Expression.SortOrder.SORT_DIRECTION_ASCENDING
+                )
+
+                sort_order_expr = expressions_proto.Expression(
+                    sort_order=expressions_proto.Expression.SortOrder(
+                        child=args[0],
+                        direction=direction,
+                    )
+                )
+                args.append(sort_order_expr)
+                proto = apply_filter_clause(func_name, [args[0]], exp)
+                # second arg is a literal value and it doesn't make sense to apply filter on it.
+                # also skips filtering on sort_order.
+                proto.unresolved_function.arguments.append(args[1])
+                proto.unresolved_function.arguments.append(sort_order_expr)
+            else:
+                proto = apply_filter_clause(func_name, args, exp)
         case "Alias":
             proto = expressions_proto.Expression(
                 alias=expressions_proto.Expression.Alias(
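
For context, the reverse flag handled above corresponds to the ordering in Spark SQL's WITHIN GROUP syntax. A hedged example of the kind of query this decoding supports (assumes an existing SparkSession spark and a table or view named items with a price column, neither of which comes from the diff):

# percentile_cont with a descending WITHIN GROUP ordering; per the change above,
# reverse=True on the aggregate maps to a SORT_DIRECTION_DESCENDING SortOrder
# expression appended as an extra argument.
df = spark.sql(
    """
    SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY price DESC) AS median_price
    FROM items
    """
)
df.show()
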
snowflake/snowpark_connect/expression/map_unresolved_attribute.py

@@ -247,7 +247,7 @@ def map_unresolved_attribute(
                 )
             )
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_col_name)
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_col_name)
             typed_col = TypedColumn(col, lambda: typer.type(col))
             typed_col.set_qualifiers(qualifiers)
             # Store matched columns info for later use

@@ -262,7 +262,7 @@ def map_unresolved_attribute(
                 )
             )
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_col_name)
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_col_name)
             typed_col = TypedColumn(col, lambda: typer.type(col))
             typed_col.set_qualifiers(qualifiers)
             return (matched_columns[0], typed_col)

@@ -280,7 +280,7 @@ def map_unresolved_attribute(
             )
         if snowpark_name is not None:
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_attr_name)
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(quoted_attr_name)
         else:
             # this means it has to be a struct column with a field name
             snowpark_name: str | None = None

@@ -338,7 +338,7 @@ def map_unresolved_attribute(
         )
         if snowpark_name is not None:
             col = get_col(snowpark_name)
-            qualifiers = column_mapping.get_qualifier_for_spark_column(
+            qualifiers = column_mapping.get_qualifiers_for_spark_column(
                 unqualified_name
             )
             typed_col = TypedColumn(col, lambda: typer.type(col))

@@ -405,7 +405,7 @@ def map_unresolved_attribute(
         for field_name in path:
             col = col.getItem(field_name)
 
-        qualifiers = []
+        qualifiers = set()
 
         typed_col = TypedColumn(col, lambda: typer.type(col))
         typed_col.set_qualifiers(qualifiers)