snowpark-connect 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (42)
  1. snowflake/snowpark_connect/column_name_handler.py +3 -93
  2. snowflake/snowpark_connect/config.py +99 -4
  3. snowflake/snowpark_connect/dataframe_container.py +0 -6
  4. snowflake/snowpark_connect/expression/map_expression.py +31 -1
  5. snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
  6. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +22 -26
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +28 -10
  8. snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  10. snowflake/snowpark_connect/relation/map_extension.py +7 -1
  11. snowflake/snowpark_connect/relation/map_join.py +62 -258
  12. snowflake/snowpark_connect/relation/map_map_partitions.py +36 -77
  13. snowflake/snowpark_connect/relation/map_relation.py +8 -2
  14. snowflake/snowpark_connect/relation/map_show_string.py +2 -0
  15. snowflake/snowpark_connect/relation/map_sql.py +413 -15
  16. snowflake/snowpark_connect/relation/write/map_write.py +195 -114
  17. snowflake/snowpark_connect/resources_initializer.py +20 -5
  18. snowflake/snowpark_connect/server.py +20 -18
  19. snowflake/snowpark_connect/utils/artifacts.py +4 -5
  20. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  21. snowflake/snowpark_connect/utils/context.py +41 -1
  22. snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
  23. snowflake/snowpark_connect/utils/identifiers.py +120 -0
  24. snowflake/snowpark_connect/utils/io_utils.py +21 -1
  25. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
  26. snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
  27. snowflake/snowpark_connect/utils/session.py +16 -26
  28. snowflake/snowpark_connect/utils/telemetry.py +53 -0
  29. snowflake/snowpark_connect/utils/udf_utils.py +66 -103
  30. snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
  31. snowflake/snowpark_connect/version.py +2 -3
  32. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/RECORD +41 -42
  34. snowflake/snowpark_connect/hidden_column.py +0 -39
  35. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.26.0.data → snowpark_connect-0.28.0.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.26.0.dist-info → snowpark_connect-0.28.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/column_name_handler.py
@@ -20,7 +20,6 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 from snowflake.snowpark._internal.utils import quote_name
 from snowflake.snowpark.types import StructType
 from snowflake.snowpark_connect.config import global_config
-from snowflake.snowpark_connect.hidden_column import HiddenColumn
 from snowflake.snowpark_connect.utils.context import get_current_operation_scope
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
@@ -124,7 +123,6 @@ class ColumnNameMap:
         ] = lambda: global_config.spark_sql_caseSensitive,
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
     ) -> None:
         """
@@ -135,7 +133,6 @@ class ColumnNameMap:
                 The key is the original Spark column name, and the value is the metadata.
                 example: Dict('age', {'foo': 'bar'})
             column_qualifiers: Optional qualifiers for the columns, used to handle table aliases or DataFrame aliases.
-            hidden_columns: Optional set of HiddenColumn objects.
             parent_column_name_map: parent ColumnNameMap
         """
         self.columns: list[ColumnNames] = []
@@ -144,7 +141,6 @@ class ColumnNameMap:
         self.snowpark_to_col = defaultdict(list)
         self.is_case_sensitive = is_case_sensitive
         self.column_metadata = column_metadata
-        self.hidden_columns = hidden_columns

         # Rename chain dictionary to track column renaming history
         self.rename_chains: dict[str, str] = {}  # old_name -> new_name mapping
@@ -338,8 +334,6 @@ class ColumnNameMap:
         *,
         allow_non_exists: bool = False,
         return_first: bool = False,
-        is_qualified: bool = False,
-        source_qualifiers: list[str] | None = None,
     ) -> str | None:
         assert isinstance(spark_column_name, str)
         resolved_name = (
@@ -347,37 +341,9 @@ class ColumnNameMap:
             if self.rename_chains
             else spark_column_name
         )
-
-        # We need to check hidden columns first. We want to avoid the code path
-        # within get_snowpark_column_names_from_spark_column_names that checks the parent ColumnNameMap.
-        # This is because that will return the name of the using column that's been dropped from the result
-        # dataframe. We want to fetch and resolve the hidden column to its visible using column name instead.
-        # Even if this is an unqualified reference or one to the visible column, it will resolve correctly to
-        # the visible name anyway.
-        snowpark_names = []
-        # Only check hidden columns for qualified references with source qualifiers
-        if is_qualified and source_qualifiers is not None and self.hidden_columns:
-            column_name = spark_column_name
-
-            # Check each hidden column for column name AND qualifier match
-            for hidden_col in self.hidden_columns:
-                if (
-                    hidden_col.spark_name == column_name
-                    and hidden_col.qualifiers == source_qualifiers
-                ):
-                    if not global_config.spark_sql_caseSensitive:
-                        if hidden_col.spark_name.upper() == column_name.upper() and [
-                            q.upper() for q in hidden_col.qualifiers
-                        ] == [q.upper() for q in source_qualifiers]:
-                            snowpark_names.append(hidden_col.visible_snowpark_name)
-                    else:
-                        snowpark_names.append(hidden_col.visible_snowpark_name)
-
-        # If not found in hidden columns, proceed with normal lookup
-        if not snowpark_names:
-            snowpark_names = self.get_snowpark_column_names_from_spark_column_names(
-                [resolved_name], return_first
-            )
+        snowpark_names = self.get_snowpark_column_names_from_spark_column_names(
+            [resolved_name], return_first
+        )

         snowpark_names_len = len(snowpark_names)
         if snowpark_names_len > 1:
@@ -464,27 +430,6 @@ class ColumnNameMap:
             snowpark_columns.append(c.snowpark_name)
             qualifiers.append(c.qualifiers)

-        # Note: The following code is commented out because there is a bug with handling duplicate columns in
-        # qualified select *'s. This needs to be revisited once a solution for that is found.
-        # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-2265240
-
-        # # Handles fetching/resolving the hidden columns if they also match the qualifiers
-        # # This method is only ever called for qualified references, so we need to check hidden columns as well.
-        # if self.hidden_columns:
-        #     for hidden_col in self.hidden_columns:
-        #         col_qualifiers = (
-        #             [q.upper() for q in hidden_col.qualifiers]
-        #             if not self.is_case_sensitive()
-        #             else hidden_col.qualifiers
-        #         )
-        #         if len(col_qualifiers) < len(qualifiers_input):
-        #             continue
-        #         if col_qualifiers[-len(qualifiers_input) :] == qualifiers_input:
-        #             # This hidden column matches! Add it to the results
-        #             spark_columns.append(hidden_col.spark_name)
-        #             snowpark_columns.append(hidden_col.visible_snowpark_name)
-        #             qualifiers.append(hidden_col.qualifiers)
-
         return spark_columns, snowpark_columns, qualifiers

     def get_snowpark_columns(self) -> list[str]:
@@ -616,35 +561,6 @@ class ColumnNameMap:
         else:
             return spark_name.upper()

-    def is_hidden_column_reference(
-        self, spark_column_name: str, source_qualifiers: list[str] | None = None
-    ) -> bool:
-        """
-        Check if a column reference would be resolved through hidden columns.
-        """
-        if not self.hidden_columns or source_qualifiers is None:
-            return False
-
-        # For qualified references with source_qualifiers
-        column_name = (
-            spark_column_name  # When has_plan_id=True, this is just the column name
-        )
-
-        for hidden_col in self.hidden_columns:
-            if (
-                hidden_col.spark_name == column_name
-                and hidden_col.qualifiers == source_qualifiers
-            ):
-                if not global_config.spark_sql_caseSensitive:
-                    if hidden_col.spark_name.upper() == column_name.upper() and [
-                        q.upper() for q in hidden_col.qualifiers
-                    ] == [q.upper() for q in source_qualifiers]:
-                        return True
-                else:
-                    return True
-
-        return False
-

 class JoinColumnNameMap(ColumnNameMap):
     def __init__(
@@ -654,9 +570,6 @@ class JoinColumnNameMap(ColumnNameMap):
     ) -> None:
         self.left_column_mapping: ColumnNameMap = left_colmap
         self.right_column_mapping: ColumnNameMap = right_colmap
-        # Ensure attributes expected by base-class helpers exist to avoid AttributeError
-        # when generic code paths (e.g., hidden column checks) touch them.
-        self.hidden_columns: set[HiddenColumn] | None = None

     def get_snowpark_column_name_from_spark_column_name(
         self,
@@ -664,9 +577,6 @@ class JoinColumnNameMap(ColumnNameMap):
         *,
         allow_non_exists: bool = False,
         return_first: bool = False,
-        # JoinColumnNameMap will never be called with using columns, so these parameters are not used.
-        is_qualified: bool = False,
-        source_qualifiers: list[str] | None = None,
     ) -> str | None:
         snowpark_column_name_in_left = (
             self.left_column_mapping.get_snowpark_column_name_from_spark_column_name(
snowflake/snowpark_connect/config.py
@@ -8,7 +8,7 @@ import re
 import sys
 from collections import defaultdict
 from copy import copy, deepcopy
-from typing import Any
+from typing import Any, Dict

 import jpype
 import pyspark.sql.connect.proto.base_pb2 as proto_base
@@ -17,6 +17,7 @@ from tzlocal import get_localzone_name
 from snowflake import snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
+    unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
@@ -171,9 +172,6 @@ class GlobalConfig:
         "spark.app.name": lambda session, name: setattr(
             session, "query_tag", f"Spark-Connect-App-Name={name}"
         ),
-        "snowpark.connect.udf.packages": lambda session, packages: session.add_packages(
-            *packages.strip("[] ").split(",")
-        ),
         "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
             session, imports
         ),
@@ -260,6 +258,7 @@ SESSION_CONFIG_KEY_WHITELIST = {
     "spark.sql.execution.pythonUDTF.arrow.enabled",
     "spark.sql.tvf.allowMultipleTableArguments.enabled",
     "snowpark.connect.sql.passthrough",
+    "snowpark.connect.cte.optimization_enabled",
     "snowpark.connect.iceberg.external_volume",
     "snowpark.connect.sql.identifiers.auto-uppercase",
     "snowpark.connect.udtf.compatibility_mode",
@@ -284,6 +283,7 @@ class SessionConfig:
     default_session_config = {
         "snowpark.connect.sql.identifiers.auto-uppercase": "all_except_columns",
         "snowpark.connect.sql.passthrough": "false",
+        "snowpark.connect.cte.optimization_enabled": "true",
         "snowpark.connect.udtf.compatibility_mode": "false",
         "snowpark.connect.views.duplicate_column_names_handling_mode": "rename",
         "spark.sql.execution.pythonUDTF.arrow.enabled": "false",
@@ -293,6 +293,7 @@ class SessionConfig:

     def __init__(self) -> None:
         self.config = deepcopy(self.default_session_config)
+        self.table_metadata: Dict[str, Dict[str, Any]] = {}

     def __getitem__(self, item: str) -> str:
         return self.get(item)
@@ -572,6 +573,12 @@ def set_snowflake_parameters(
                     snowpark_session.use_database(db)
                 case (prev, curr) if prev != curr:
                     snowpark_session.use_schema(prev)
+        case "snowpark.connect.cte.optimization_enabled":
+            # Set CTE optimization on the snowpark session
+            cte_enabled = str_to_bool(value)
+            snowpark_session.cte_optimization_enabled = cte_enabled
+            logger.info(f"Updated snowpark session CTE optimization: {cte_enabled}")
+
         case _:
             pass

@@ -581,6 +588,16 @@ def get_boolean_session_config_param(name: str) -> bool:
     return str_to_bool(session_config[name])


+def get_string_session_config_param(name: str) -> str:
+    session_config = sessions_config[get_session_id()]
+    return str(session_config[name])
+
+
+def get_cte_optimization_enabled() -> bool:
+    """Get the CTE optimization configuration setting."""
+    return get_boolean_session_config_param("snowpark.connect.cte.optimization_enabled")
+
+
 def auto_uppercase_column_identifiers() -> bool:
     session_config = sessions_config[get_session_id()]
     return session_config[
@@ -616,3 +633,81 @@ def get_timestamp_type():
         # shouldn't happen since `spark.sql.timestampType` is always defined, and `spark.conf.unset` sets it to default (TIMESTAMP_LTZ)
         timestamp_type = TimestampType(TimestampTimeZone.LTZ)
     return timestamp_type
+
+
+def record_table_metadata(
+    table_identifier: str,
+    table_type: str,
+    data_source: str,
+    supports_column_rename: bool = True,
+) -> None:
+    """
+    Record metadata about a table for Spark compatibility checks.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+        table_type: "v1" or "v2"
+        data_source: Source format (parquet, csv, iceberg, etc.)
+        supports_column_rename: Whether the table supports RENAME COLUMN
+    """
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+
+    # Normalize table identifier for consistent lookup
+    # Use the full catalog.database.table identifier to avoid conflicts
+    normalized_identifier = table_identifier.upper().strip('"')
+
+    session_config.table_metadata[normalized_identifier] = {
+        "table_type": table_type,
+        "data_source": data_source,
+        "supports_column_rename": supports_column_rename,
+    }
+
+
+def get_table_metadata(table_identifier: str) -> Dict[str, Any] | None:
+    """
+    Get stored metadata for a table.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+
+    Returns:
+        Table metadata dict or None if not found
+    """
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+
+    normalized_identifier = unquote_if_quoted(table_identifier).upper()
+
+    return session_config.table_metadata.get(normalized_identifier)
+
+
+def check_table_supports_operation(table_identifier: str, operation: str) -> bool:
+    """
+    Check if a table supports a given operation based on metadata and config.
+
+    Args:
+        table_identifier: Full table identifier (catalog.database.table)
+        operation: Operation to check (e.g., "rename_column")
+
+    Returns:
+        True if operation is supported, False if should be blocked
+    """
+    table_metadata = get_table_metadata(table_identifier)
+
+    if not table_metadata:
+        return True
+
+    session_id = get_session_id()
+    session_config = sessions_config[session_id]
+    enable_extensions = str_to_bool(
+        session_config.get("enable_snowflake_extension_behavior", "false")
+    )
+
+    if enable_extensions:
+        return True
+
+    if operation == "rename_column":
+        return table_metadata.get("supports_column_rename", True)
+
+    return True
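
For orientation, a minimal sketch of how a client could exercise the new "snowpark.connect.cte.optimization_enabled" key added to the whitelist and defaults above. The session construction, connect URL, and use of spark.conf are assumptions for illustration and are not part of this diff.

# Sketch only: assumes a Spark Connect client pointed at a snowpark-connect server.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # URL is an assumption

# The key defaults to "true" (see default_session_config above); per set_snowflake_parameters,
# setting it is forwarded to the Snowpark session as cte_optimization_enabled.
spark.conf.set("snowpark.connect.cte.optimization_enabled", "false")
print(spark.conf.get("snowpark.connect.cte.optimization_enabled"))  # expected: "false"
spark.conf.set("snowpark.connect.cte.optimization_enabled", "true")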
snowflake/snowpark_connect/dataframe_container.py
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Callable

 from snowflake import snowpark
 from snowflake.snowpark.types import StructField, StructType
-from snowflake.snowpark_connect.hidden_column import HiddenColumn

 if TYPE_CHECKING:
     from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -61,7 +60,6 @@ class DataFrameContainer:
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
@@ -78,7 +76,6 @@ class DataFrameContainer:
             column_metadata: Optional metadata dictionary
             column_qualifiers: Optional column qualifiers
             parent_column_name_map: Optional parent column name map
-            hidden_columns: Optional list of hidden column names
             table_name: Optional table name
             alias: Optional alias
             cached_schema_getter: Optional function to get cached schema
@@ -101,7 +98,6 @@ class DataFrameContainer:
             column_metadata,
             column_qualifiers,
             parent_column_name_map,
-            hidden_columns,
         )

         # Determine the schema getter to use
@@ -226,7 +222,6 @@ class DataFrameContainer:
         column_metadata: dict | None = None,
         column_qualifiers: list[list[str]] | None = None,
         parent_column_name_map: ColumnNameMap | None = None,
-        hidden_columns: set[HiddenColumn] | None = None,
     ) -> ColumnNameMap:
         """Create a ColumnNameMap with the provided configuration."""
         from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
@@ -237,7 +232,6 @@ class DataFrameContainer:
             column_metadata=column_metadata,
             column_qualifiers=column_qualifiers,
             parent_column_name_map=parent_column_name_map,
-            hidden_columns=hidden_columns,
         )

     @staticmethod
snowflake/snowpark_connect/expression/map_expression.py
@@ -6,6 +6,7 @@ import datetime
 from collections import defaultdict

 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+from pyspark.errors.exceptions.connect import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
@@ -34,8 +35,10 @@ from snowflake.snowpark_connect.type_mapping import (
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     gen_sql_plan_id,
+    get_current_lambda_params,
     is_function_argument_being_resolved,
     is_lambda_being_resolved,
+    not_resolving_fun_args,
 )
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -134,7 +137,10 @@ def map_expression(
         case "expression_string":
             return map_sql_expr(exp, column_mapping, typer)
         case "extension":
-            return map_extension.map_extension(exp, column_mapping, typer)
+            # Extensions can be passed as function args, and we need to reset the context here.
+            # Matters only for resolving alias expressions in the extensions rel.
+            with not_resolving_fun_args():
+                return map_extension.map_extension(exp, column_mapping, typer)
         case "lambda_function":
             lambda_name, lambda_body = map_single_column_expression(
                 exp.lambda_function.function, column_mapping, typer
@@ -271,6 +277,30 @@ def map_expression(
         case "unresolved_function":
             return map_func.map_unresolved_function(exp, column_mapping, typer)
         case "unresolved_named_lambda_variable":
+            # Validate that this lambda variable is in scope
+            var_name = exp.unresolved_named_lambda_variable.name_parts[0]
+            current_params = get_current_lambda_params()
+
+            if current_params and var_name not in current_params:
+                outer_col_name = (
+                    column_mapping.get_snowpark_column_name_from_spark_column_name(
+                        var_name, allow_non_exists=True
+                    )
+                )
+                if outer_col_name:
+                    col = snowpark_fn.col(outer_col_name)
+                    return ["namedlambdavariable()"], TypedColumn(
+                        col, lambda: typer.type(col)
+                    )
+                else:
+                    raise AnalysisException(
+                        f"Cannot resolve variable '{var_name}' within lambda function. "
+                        f"Lambda functions can access their own parameters and parent dataframe columns. "
+                        f"Current lambda parameters: {current_params}. "
+                        f"If '{var_name}' is an outer scope lambda variable from a nested lambda, "
+                        f"that is an unsupported feature in Snowflake SQL."
+                    )
+
             col = snowpark_fn.Column(
                 UnresolvedAttribute(exp.unresolved_named_lambda_variable.name_parts[0])
             )
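
The new unresolved_named_lambda_variable validation above rejects lambda variables that are not in scope. A hedged sketch of the higher-order-function calls it applies to follows; the session, dataframe, and column names are illustrative, and the exact pass/fail behaviour should be confirmed against this release.

# Sketch only: assumes an active Spark Connect session `spark`.
from pyspark.sql import functions as F

df = spark.createDataFrame([([1, 2, 3],)], ["xs"])

# The lambda uses only its own parameter, so the scope check is satisfied.
df.select(F.transform("xs", lambda x: x + 1).alias("xs_plus_one")).show()

# A nested lambda that closes over the outer lambda's variable `x` is the case the
# new AnalysisException message describes as unsupported in Snowflake SQL.
df.select(F.transform("xs", lambda x: F.transform(F.array(x), lambda y: x + y)))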
snowflake/snowpark_connect/expression/map_sql_expression.py
@@ -11,9 +11,10 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
 from google.protobuf.any_pb2 import Any
 from pyspark.errors.exceptions.base import AnalysisException
+from pyspark.sql.connect import functions as pyspark_functions

 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_proto
-from snowflake import snowpark
+from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.typed_column import TypedColumn
@@ -89,6 +90,11 @@ def as_scala_seq(input):
     )


+@cache
+def _scala_some():
+    return jpype.JClass("scala.Some")
+
+
 def map_sql_expr(
     exp: expressions_proto.Expression,
     column_mapping: ColumnNameMap,
@@ -223,9 +229,6 @@ def apply_filter_clause(


 def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Expression:
-    from snowflake.snowpark_connect.expression.map_expression import (
-        map_single_column_expression,
-    )
     from snowflake.snowpark_connect.relation.map_sql import map_logical_plan_relation

     class_name = str(exp.getClass().getSimpleName())
@@ -308,22 +311,23 @@ def map_logical_plan_expression(exp: jpype.JObject) -> expressions_proto.Express
             )
             proto = expressions_proto.Expression(extension=any_proto)
         case "ExpressionWithUnresolvedIdentifier":
-            plan_id = None
-            identifierExpr = map_logical_plan_expression(exp.identifierExpr())
-            session = snowpark.Session.get_active_session()
-            m = ColumnNameMap([], [], None)
-            expr = map_single_column_expression(
-                identifierExpr, m, ExpressionTyper.dummy_typer(session)
+            from snowflake.snowpark_connect.relation.map_sql import (
+                get_relation_identifier_name,
             )
-            value = session.range(1).select(expr[1].col).collect()[0][0]

-            proto = expressions_proto.Expression(
-                unresolved_attribute=expressions_proto.Expression.UnresolvedAttribute(
-                    unparsed_identifier=str(value),
-                    plan_id=plan_id,
-                ),
-            )
-            # TODO: support identifier referencing unresolved function
+            value = unquote_if_quoted(get_relation_identifier_name(exp))
+            if getattr(pyspark_functions, value.lower(), None) is not None:
+                unresolved_function = exp.exprBuilder().apply(
+                    _scala_some()(value).toList()
+                )
+                proto = map_logical_plan_expression(unresolved_function)
+            else:
+                proto = expressions_proto.Expression(
+                    unresolved_attribute=expressions_proto.Expression.UnresolvedAttribute(
+                        unparsed_identifier=str(value),
+                        plan_id=None,
+                    ),
+                )
         case "InSubquery":
             rel_proto = map_logical_plan_relation(exp.query().plan())
             any_proto = Any()
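
The rewritten ExpressionWithUnresolvedIdentifier branch above distinguishes identifiers that name a PySpark function from plain column identifiers. A hedged sketch of the Spark SQL IDENTIFIER() usage this targets; the table and column names are illustrative and not taken from this diff.

# Sketch only: assumes an active Spark Connect session `spark` and a table people(name STRING).
# IDENTIFIER('name') maps to an unresolved attribute (a column reference), while
# IDENTIFIER('upper') matches a pyspark.sql.functions name and is mapped to a function call.
spark.sql("SELECT IDENTIFIER('name') FROM people").show()
spark.sql("SELECT IDENTIFIER('upper')(name) FROM people").show()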
snowflake/snowpark_connect/expression/map_unresolved_attribute.py
@@ -22,6 +22,7 @@ from snowflake.snowpark_connect.utils.context import (
     get_is_evaluating_sql,
     get_outer_dataframes,
     get_plan_id_map,
+    is_lambda_being_resolved,
     resolve_lca_alias,
 )
 from snowflake.snowpark_connect.utils.identifiers import (
@@ -162,7 +163,6 @@ def map_unresolved_attribute(
     attr_name = ".".join(name_parts)

     has_plan_id = exp.unresolved_attribute.HasField("plan_id")
-    source_qualifiers = None

     if has_plan_id:
         plan_id = exp.unresolved_attribute.plan_id
@@ -171,27 +171,13 @@ def map_unresolved_attribute(
         assert (
             target_df is not None
         ), f"resolving an attribute of a unresolved dataframe {plan_id}"
-
-        # Get the qualifiers for this column from the target DataFrame
-        source_qualifiers = (
-            target_df_container.column_map.get_qualifier_for_spark_column(
-                name_parts[-1]
-            )
-        )
-
-        if hasattr(column_mapping, "hidden_columns"):
-            hidden = column_mapping.hidden_columns
-        else:
-            hidden = None
-
         column_mapping = target_df_container.column_map
-        column_mapping.hidden_columns = hidden
         typer = ExpressionTyper(target_df)

-    def get_col(snowpark_name, has_hidden=False):
+    def get_col(snowpark_name):
         return (
             snowpark_fn.col(snowpark_name)
-            if not has_plan_id or has_hidden
+            if not has_plan_id
             else target_df.col(snowpark_name)
         )

@@ -276,17 +262,10 @@ def map_unresolved_attribute(
     quoted_attr_name = name_parts[0]

     snowpark_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
-        quoted_attr_name,
-        allow_non_exists=True,
-        is_qualified=has_plan_id,
-        source_qualifiers=source_qualifiers if has_plan_id else None,
+        quoted_attr_name, allow_non_exists=True
     )
-
     if snowpark_name is not None:
-        is_hidden = column_mapping.is_hidden_column_reference(
-            quoted_attr_name, source_qualifiers
-        )
-        col = get_col(snowpark_name, is_hidden)
+        col = get_col(snowpark_name)
         qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_attr_name)
     else:
         # this means it has to be a struct column with a field name
@@ -356,6 +335,23 @@ def map_unresolved_attribute(
             return (unqualified_name, typed_col)

     if snowpark_name is None:
+        # Check if we're inside a lambda and trying to reference an outer column
+        # This catches direct column references (not lambda variables)
+        if is_lambda_being_resolved() and column_mapping:
+            # Check if this column exists in the outer scope (not lambda params)
+            outer_col_name = (
+                column_mapping.get_snowpark_column_name_from_spark_column_name(
+                    attr_name, allow_non_exists=True
+                )
+            )
+            if outer_col_name:
+                # This is an outer scope column being referenced inside a lambda
+                raise AnalysisException(
+                    f"Reference to non-lambda variable '{attr_name}' within lambda function. "
+                    f"Lambda functions can only access their own parameters. "
+                    f"If '{attr_name}' is a table column, it must be passed as an explicit parameter to the enclosing function."
+                )
+
         if has_plan_id:
             raise AnalysisException(
                 f'[RESOLVED_REFERENCE_COLUMN_NOT_FOUND] The column "{attr_name}" does not exist in the target dataframe.'
snowflake/snowpark_connect/expression/map_unresolved_function.py
@@ -711,6 +711,9 @@ def map_unresolved_function(
                 "-",
             )
             match (snowpark_typed_args[0].typ, snowpark_typed_args[1].typ):
+                case (DateType(), NullType()) | (NullType(), DateType()):
+                    result_type = LongType()
+                    result_exp = snowpark_fn.lit(None).cast(result_type)
                 case (NullType(), _) | (_, NullType()):
                     result_type = _get_add_sub_result_type(
                         snowpark_typed_args[0].typ,
@@ -724,7 +727,10 @@ def map_unresolved_function(
                     result_type = LongType()
                     result_exp = snowpark_args[0] - snowpark_args[1]
                 case (DateType(), StringType()):
-                    if "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name:
+                    if (
+                        hasattr(snowpark_typed_args[1].col._expr1, "pretty_name")
+                        and "INTERVAL" == snowpark_typed_args[1].col._expr1.pretty_name
+                    ):
                         result_type = TimestampType()
                         result_exp = snowpark_args[0] - snowpark_args[1]
                     else:
@@ -2421,7 +2427,7 @@ def map_unresolved_function(
                     "try_to_date",
                     snowpark_fn.cast(
                         truncated_date,
-                        TimestampType(snowpark.types.TimestampTimeZone.NTZ),
+                        TimestampType(),
                     ),
                     snowpark_args[1],
                 )
@@ -2613,9 +2619,18 @@ def map_unresolved_function(
                     result_type = input_type.element_type
                     result_exp = fn(snowpark_args[0])
                 case _:
-                    spark_col_names = ["key", "value"]
-                    result_exp = fn(snowpark_args[0])
-                    result_type = [input_type.key_type, input_type.value_type]
+                    # Check if the type has map-like attributes before accessing them
+                    if hasattr(input_type, "key_type") and hasattr(
+                        input_type, "value_type"
+                    ):
+                        spark_col_names = ["key", "value"]
+                        result_exp = fn(snowpark_args[0])
+                        result_type = [input_type.key_type, input_type.value_type]
+                    else:
+                        # Throw proper error for types without key_type/value_type attributes
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the ("ARRAY" or "MAP") type, however "{snowpark_arg_names[0]}" has the type "{str(input_type)}".'
+                        )
         case "expm1":
             spark_function_name = f"EXPM1({snowpark_arg_names[0]})"
             result_exp = snowpark_fn.exp(*snowpark_args) - 1
@@ -8725,7 +8740,7 @@ def _resolve_function_with_lambda(
     artificial_df = Session.get_active_session().create_dataframe([], schema)
     set_schema_getter(artificial_df, lambda: schema)

-    with resolving_lambda_function():
+    with resolving_lambda_function(names):
        return map_expression(
            (
                lambda_exp.lambda_function.function
@@ -9911,7 +9926,10 @@ def _get_spark_function_name(
             return f"({date_param_name1} {operation_op} {date_param_name2})"
         case (StringType(), DateType()):
             date_param_name2 = _get_literal_param_name(exp, 1, snowpark_arg_names[1])
-            if "INTERVAL" == col1.col._expr1.pretty_name:
+            if (
+                hasattr(col1.col._expr1, "pretty_name")
+                and "INTERVAL" == col1.col._expr1.pretty_name
+            ):
                 return f"{date_param_name2} {operation_op} {snowpark_arg_names[0]}"
             elif global_config.spark_sql_ansi_enabled and function_name == "+":
                 return f"{operation_func}(cast({date_param_name2} as date), cast({snowpark_arg_names[0]} as double))"
@@ -9919,9 +9937,9 @@ def _get_spark_function_name(
             return f"({snowpark_arg_names[0]} {operation_op} {date_param_name2})"
         case (DateType(), StringType()):
             date_param_name1 = _get_literal_param_name(exp, 0, snowpark_arg_names[0])
-            if (
-                global_config.spark_sql_ansi_enabled
-                or "INTERVAL" == col2.col._expr1.pretty_name
+            if global_config.spark_sql_ansi_enabled or (
+                hasattr(col2.col._expr1, "pretty_name")
+                and "INTERVAL" == col2.col._expr1.pretty_name
             ):
                 return f"{date_param_name1} {operation_op} {snowpark_arg_names[1]}"
             else: