snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic. See the package's registry page for more details.

Files changed (46)
  1. snowflake/snowpark_connect/config.py +0 -11
  2. snowflake/snowpark_connect/error/error_utils.py +7 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  5. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  6. snowflake/snowpark_connect/expression/literal.py +14 -12
  7. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  8. snowflake/snowpark_connect/expression/map_expression.py +18 -2
  9. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  10. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
  12. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  16. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  17. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  18. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  19. snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
  20. snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
  21. snowflake/snowpark_connect/relation/map_extension.py +65 -31
  22. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  23. snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
  24. snowflake/snowpark_connect/relation/map_sql.py +22 -5
  25. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  26. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  27. snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +243 -68
  29. snowflake/snowpark_connect/server.py +25 -5
  30. snowflake/snowpark_connect/type_mapping.py +2 -2
  31. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  32. snowflake/snowpark_connect/utils/session.py +21 -0
  33. snowflake/snowpark_connect/version.py +1 -1
  34. snowflake/snowpark_decoder/spark_decoder.py +1 -1
  35. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
  36. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
  37. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  38. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  39. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  40. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  41. {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  42. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
@@ -4,10 +4,8 @@
4
4
 
5
5
  # Proto source for reference:
6
6
  # https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
7
- import os
8
7
  import re
9
8
  import sys
10
- import time
11
9
  from collections import defaultdict
12
10
  from copy import copy, deepcopy
13
11
  from typing import Any
@@ -533,17 +531,8 @@ def set_snowflake_parameters(
533
531
  snowpark_session.sql(
534
532
  f"ALTER SESSION SET TIMEZONE = '{value}'"
535
533
  ).collect()
536
- set_jvm_timezone(value)
537
- if hasattr(time, "tzset"):
538
- os.environ["TZ"] = value
539
- time.tzset()
540
534
  else:
541
535
  snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
542
- reset_jvm_timezone_to_system_default()
543
- if hasattr(time, "tzset") and False:
544
- if "TZ" in os.environ:
545
- del os.environ["TZ"]
546
- time.tzset()
547
536
  case "spark.sql.globalTempDatabase":
548
537
  if not value:
549
538
  value = global_config.default_static_global_config.get(key)
@@ -75,6 +75,10 @@ terminate_multi_args_exception_pattern = (
75
75
  snowpark_connect_exception_pattern = re.compile(
76
76
  r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
77
77
  )
78
+ invalid_bit_pattern = re.compile(
79
+ r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
80
+ re.IGNORECASE,
81
+ )
78
82
 
79
83
 
80
84
  def contains_udtf_select(sql_string):
@@ -107,6 +111,9 @@ def _get_converted_known_sql_or_custom_exception(
107
111
  return SparkRuntimeException(
108
112
  message="Unexpected value for start in function slice: SQL array indices start at 1."
109
113
  )
114
+ invalid_bit = invalid_bit_pattern.search(msg)
115
+ if invalid_bit:
116
+ return IllegalArgumentException(message=invalid_bit.group(0))
110
117
  match = snowpark_connect_exception_pattern.search(
111
118
  ex.message if hasattr(ex, "message") else str(ex)
112
119
  )
@@ -22,3 +22,7 @@ class MissingSchema(SnowparkConnectException):
22
22
  super().__init__(
23
23
  "No default schema found in session",
24
24
  )
25
+
26
+
27
+ class MaxRetryExceeded(SnowparkConnectException):
28
+ ...
@@ -0,0 +1,207 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2
8
+ import pyspark.sql.connect.proto.types_pb2 as types_pb2
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class DefaultParameter:
13
+ """Represents a single default parameter for a function."""
14
+
15
+ name: str
16
+ value: Any
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class FunctionDefaults:
21
+ """Represents default parameter configuration for a function."""
22
+
23
+ total_args: int
24
+ defaults: list[DefaultParameter]
25
+
26
+
27
+ # FUNCTION_DEFAULTS dictionary to hold operation name with default values.
28
+ # This is required as non pyspark clients such as scala or sql won't send all the parameters.
29
+ # We use this dict to inject the missing parameters before processing the unresolved function.
30
+ FUNCTION_DEFAULTS: dict[str, FunctionDefaults] = {
31
+ "aes_decrypt": FunctionDefaults(
32
+ total_args=5,
33
+ defaults=[
34
+ DefaultParameter("mode", "GCM"), # Spark SQL default: GCM
35
+ DefaultParameter("padding", "NONE"), # Spark SQL default: NONE for GCM mode
36
+ DefaultParameter("aad", ""), # Spark SQL default: empty string
37
+ ],
38
+ ),
39
+ "aes_encrypt": FunctionDefaults(
40
+ total_args=6,
41
+ defaults=[
42
+ DefaultParameter("mode", "GCM"), # Spark SQL default: GCM
43
+ DefaultParameter("padding", "NONE"), # Spark SQL default: NONE for GCM mode
44
+ DefaultParameter(
45
+ "iv", ""
46
+ ), # Spark SQL default: empty string (random generated if not provided)
47
+ DefaultParameter("aad", ""), # Spark SQL default: empty string
48
+ ],
49
+ ),
50
+ "approx_percentile": FunctionDefaults(
51
+ total_args=3,
52
+ defaults=[DefaultParameter("accuracy", 10000)],
53
+ ),
54
+ "bround": FunctionDefaults(
55
+ total_args=2,
56
+ defaults=[DefaultParameter("scale", 0)],
57
+ ),
58
+ "first": FunctionDefaults(
59
+ total_args=2,
60
+ defaults=[DefaultParameter("ignorenulls", False)],
61
+ ),
62
+ "lag": FunctionDefaults(
63
+ total_args=2,
64
+ defaults=[
65
+ DefaultParameter("offset", 1),
66
+ ],
67
+ ),
68
+ "last": FunctionDefaults(
69
+ total_args=2,
70
+ defaults=[DefaultParameter("ignorenulls", False)],
71
+ ),
72
+ "lead": FunctionDefaults(
73
+ total_args=3,
74
+ defaults=[DefaultParameter("offset", 1), DefaultParameter("default", None)],
75
+ ),
76
+ "locate": FunctionDefaults(
77
+ total_args=3,
78
+ defaults=[DefaultParameter("pos", 1)],
79
+ ),
80
+ "months_between": FunctionDefaults(
81
+ total_args=3,
82
+ defaults=[DefaultParameter("roundOff", True)],
83
+ ),
84
+ "nth_value": FunctionDefaults(
85
+ total_args=3,
86
+ defaults=[DefaultParameter("ignoreNulls", False)],
87
+ ),
88
+ "overlay": FunctionDefaults(
89
+ total_args=4,
90
+ defaults=[DefaultParameter("len", -1)],
91
+ ),
92
+ "percentile": FunctionDefaults(
93
+ total_args=3,
94
+ defaults=[DefaultParameter("frequency", 1)],
95
+ ),
96
+ "percentile_approx": FunctionDefaults(
97
+ total_args=3,
98
+ defaults=[DefaultParameter("accuracy", 10000)],
99
+ ),
100
+ "round": FunctionDefaults(
101
+ total_args=2,
102
+ defaults=[DefaultParameter("scale", 0)],
103
+ ),
104
+ "sentences": FunctionDefaults(
105
+ total_args=3,
106
+ defaults=[
107
+ DefaultParameter("language", ""),
108
+ DefaultParameter("country", ""),
109
+ ],
110
+ ),
111
+ "sort_array": FunctionDefaults(
112
+ total_args=2,
113
+ defaults=[DefaultParameter("asc", True)],
114
+ ),
115
+ "split": FunctionDefaults(
116
+ total_args=3,
117
+ defaults=[DefaultParameter("limit", -1)],
118
+ ),
119
+ "str_to_map": FunctionDefaults(
120
+ total_args=3,
121
+ defaults=[
122
+ DefaultParameter(
123
+ "pairDelim", ","
124
+ ), # Spark SQL default: comma for splitting pairs
125
+ DefaultParameter(
126
+ "keyValueDelim", ":"
127
+ ), # Spark SQL default: colon for splitting key/value
128
+ ],
129
+ ),
130
+ "try_aes_decrypt": FunctionDefaults(
131
+ total_args=5,
132
+ defaults=[
133
+ DefaultParameter("mode", "GCM"), # Spark SQL default: GCM
134
+ DefaultParameter("padding", "NONE"), # Spark SQL default: NONE for GCM mode
135
+ DefaultParameter("aad", ""), # Spark SQL default: empty string
136
+ ],
137
+ ),
138
+ }
139
+
140
+
141
+ def _create_literal_expression(value: Any) -> expressions_pb2.Expression:
142
+ """Create a literal expression for the given value."""
143
+ expr = expressions_pb2.Expression()
144
+ if isinstance(value, bool):
145
+ expr.literal.boolean = value
146
+ elif isinstance(value, int):
147
+ expr.literal.integer = value
148
+ elif isinstance(value, str):
149
+ expr.literal.string = value
150
+ elif isinstance(value, float):
151
+ expr.literal.double = value
152
+ elif value is None:
153
+ null_type = types_pb2.DataType()
154
+ null_type.null.SetInParent()
155
+ expr.literal.null.CopyFrom(null_type)
156
+ else:
157
+ raise ValueError(f"Unsupported literal type: {value}")
158
+
159
+ return expr
160
+
161
+
162
+ def inject_function_defaults(
163
+ unresolved_function: expressions_pb2.Expression.UnresolvedFunction,
164
+ ) -> bool:
165
+ """
166
+ Inject missing default parameters into an UnresolvedFunction protobuf.
167
+
168
+ Args:
169
+ unresolved_function: The protobuf UnresolvedFunction to modify
170
+
171
+ Returns:
172
+ bool: True if any defaults were injected, False otherwise
173
+ """
174
+ function_name = unresolved_function.function_name.lower()
175
+
176
+ if function_name not in FUNCTION_DEFAULTS:
177
+ return False
178
+
179
+ func_config = FUNCTION_DEFAULTS[function_name]
180
+ current_arg_count = len(unresolved_function.arguments)
181
+ total_args = func_config.total_args
182
+ defaults = func_config.defaults
183
+
184
+ if not defaults or current_arg_count >= total_args:
185
+ return False
186
+
187
+ # Calculate how many defaults to append
188
+ missing_arg_count = total_args - current_arg_count
189
+
190
+ # Check if any required params are missing.
191
+ if missing_arg_count > len(defaults):
192
+ raise ValueError(
193
+ f"Function '{function_name}' is missing required arguments. "
194
+ f"Expected {total_args} args, got {current_arg_count}, "
195
+ f"but only {len(defaults)} defaults are defined."
196
+ )
197
+
198
+ defaults_to_append = defaults[-missing_arg_count:]
199
+ injected = False
200
+
201
+ # Simply append the needed default values
202
+ for default_param in defaults_to_append:
203
+ default_expr = _create_literal_expression(default_param.value)
204
+ unresolved_function.arguments.append(default_expr)
205
+ injected = True
206
+
207
+ return injected
@@ -0,0 +1,192 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ """
6
+ Hybrid column mapping for HAVING clause resolution.
7
+
8
+ This module provides a special column mapping that can resolve expressions
9
+ in the context of both the input DataFrame (for base columns) and the
10
+ aggregated DataFrame (for aggregate expressions and aliases).
11
+ """
12
+
13
+ from typing import Dict, List
14
+
15
+ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
16
+
17
+ from snowflake import snowpark
18
+ from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
19
+ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
20
+ from snowflake.snowpark_connect.typed_column import TypedColumn
21
+
22
+
23
+ class HybridColumnMap:
24
+ """
25
+ A column mapping that can resolve expressions in both input and aggregated contexts.
26
+
27
+ This is specifically designed for HAVING clause resolution where expressions may reference:
28
+ 1. Base columns from the input DataFrame (to build new aggregates)
29
+ 2. Existing aggregate expressions and their aliases
30
+ 3. Grouping columns
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ input_column_map: ColumnNameMap,
36
+ input_typer: ExpressionTyper,
37
+ aggregated_column_map: ColumnNameMap,
38
+ aggregated_typer: ExpressionTyper,
39
+ aggregate_expressions: List[expressions_proto.Expression],
40
+ grouping_expressions: List[expressions_proto.Expression],
41
+ aggregate_aliases: Dict[str, expressions_proto.Expression],
42
+ ) -> None:
43
+ self.input_column_map = input_column_map
44
+ self.input_typer = input_typer
45
+ self.aggregated_column_map = aggregated_column_map
46
+ self.aggregated_typer = aggregated_typer
47
+ self.aggregate_expressions = aggregate_expressions
48
+ self.grouping_expressions = grouping_expressions
49
+ self.aggregate_aliases = aggregate_aliases
50
+
51
+ def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
52
+ """Check if an expression is an aggregate function."""
53
+ if exp.WhichOneof("expr_type") == "unresolved_function":
54
+ func_name = exp.unresolved_function.function_name.lower()
55
+ # Common aggregate functions - expand this list as needed
56
+ aggregate_functions = {
57
+ "avg",
58
+ "average",
59
+ "sum",
60
+ "count",
61
+ "min",
62
+ "max",
63
+ "stddev",
64
+ "stddev_pop",
65
+ "stddev_samp",
66
+ "variance",
67
+ "var_pop",
68
+ "var_samp",
69
+ "collect_list",
70
+ "collect_set",
71
+ "first",
72
+ "last",
73
+ "any_value",
74
+ "bool_and",
75
+ "bool_or",
76
+ "corr",
77
+ "covar_pop",
78
+ "covar_samp",
79
+ "kurtosis",
80
+ "skewness",
81
+ "percentile_cont",
82
+ "percentile_disc",
83
+ "approx_count_distinct",
84
+ }
85
+ return func_name in aggregate_functions
86
+ return False
87
+
88
+ def is_grouping_column(self, column_name: str) -> bool:
89
+ """Check if a column name refers to a grouping column."""
90
+ for group_exp in self.grouping_expressions:
91
+ if (
92
+ group_exp.WhichOneof("expr_type") == "unresolved_attribute"
93
+ and group_exp.unresolved_attribute.unparsed_identifier == column_name
94
+ ):
95
+ return True
96
+ return False
97
+
98
+ def resolve_expression(
99
+ self, exp: expressions_proto.Expression
100
+ ) -> tuple[list[str], TypedColumn]:
101
+ """
102
+ Resolve an expression in the hybrid context.
103
+
104
+ Strategy:
105
+ 1. If it's an aggregate function -> create new aggregate using input context
106
+ 2. If it's an alias to existing aggregate -> use aggregated context
107
+ 3. If it's a grouping column -> try aggregated context first, fall back to input context
108
+ (handles exclude_grouping_columns=True case)
109
+ 4. Otherwise -> try input context first, then aggregated context
110
+ """
111
+ from snowflake.snowpark_connect.expression.map_expression import map_expression
112
+
113
+ expr_type = exp.WhichOneof("expr_type")
114
+
115
+ # Handle aggregate functions - need to evaluate against input DataFrame
116
+ if self.is_aggregate_function(exp):
117
+ return map_expression(exp, self.input_column_map, self.input_typer)
118
+
119
+ # Handle column references
120
+ if expr_type == "unresolved_attribute":
121
+ column_name = exp.unresolved_attribute.unparsed_identifier
122
+
123
+ # Check if it's an alias to an existing aggregate expression
124
+ if column_name in self.aggregate_aliases:
125
+ # Use the aggregated context to get the alias
126
+ return map_expression(
127
+ exp, self.aggregated_column_map, self.aggregated_typer
128
+ )
129
+
130
+ # Check if it's a grouping column
131
+ if self.is_grouping_column(column_name):
132
+ # Try aggregated context first (for cases where grouping columns are included)
133
+ try:
134
+ return map_expression(
135
+ exp, self.aggregated_column_map, self.aggregated_typer
136
+ )
137
+ except Exception:
138
+ # Fall back to input context if grouping columns were excluded
139
+ # This handles the exclude_grouping_columns=True case
140
+ return map_expression(exp, self.input_column_map, self.input_typer)
141
+
142
+ # Try input context first (for base columns used in new aggregates)
143
+ try:
144
+ return map_expression(exp, self.input_column_map, self.input_typer)
145
+ except Exception:
146
+ # Fall back to aggregated context
147
+ return map_expression(
148
+ exp, self.aggregated_column_map, self.aggregated_typer
149
+ )
150
+
151
+ # For other expression types, try aggregated context first (likely references to computed values)
152
+ try:
153
+ return map_expression(
154
+ exp, self.aggregated_column_map, self.aggregated_typer
155
+ )
156
+ except Exception:
157
+ # Fall back to input context
158
+ return map_expression(exp, self.input_column_map, self.input_typer)
159
+
160
+
161
+ def create_hybrid_column_map_for_having(
162
+ input_df: snowpark.DataFrame,
163
+ input_column_map: ColumnNameMap,
164
+ aggregated_df: snowpark.DataFrame,
165
+ aggregated_column_map: ColumnNameMap,
166
+ aggregate_expressions: List[expressions_proto.Expression],
167
+ grouping_expressions: List[expressions_proto.Expression],
168
+ spark_columns: List[str],
169
+ raw_aggregations: List[tuple[str, TypedColumn]],
170
+ ) -> HybridColumnMap:
171
+ """
172
+ Create a HybridColumnMap instance for HAVING clause resolution.
173
+ """
174
+ # Create typers for both contexts
175
+ input_typer = ExpressionTyper(input_df)
176
+ aggregated_typer = ExpressionTyper(aggregated_df)
177
+
178
+ # Build alias mapping from spark column names to aggregate expressions
179
+ aggregate_aliases = {}
180
+ for i, (spark_name, _) in enumerate(raw_aggregations):
181
+ if i < len(aggregate_expressions):
182
+ aggregate_aliases[spark_name] = aggregate_expressions[i]
183
+
184
+ return HybridColumnMap(
185
+ input_column_map=input_column_map,
186
+ input_typer=input_typer,
187
+ aggregated_column_map=aggregated_column_map,
188
+ aggregated_typer=aggregated_typer,
189
+ aggregate_expressions=aggregate_expressions,
190
+ grouping_expressions=grouping_expressions,
191
+ aggregate_aliases=aggregate_aliases,
192
+ )
@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
10
10
  from tzlocal import get_localzone
11
11
 
12
12
  from snowflake.snowpark_connect.config import global_config
13
+ from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
13
14
  from snowflake.snowpark_connect.utils.telemetry import (
14
15
  SnowparkConnectNotImplementedError,
15
16
  )
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
47
48
  ).date()
48
49
  return date, f"DATE '{date}'"
49
50
  case "timestamp" | "timestamp_ntz" as t:
50
- # Note - Clients need to ensure local_timezone is the same as spark_sql_session_timeZone config.
51
- # No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
52
- # the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
53
- # timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
54
- # to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
55
- # official doc about it, but this behavior is based on my testings.
56
- tz = (
57
- ZoneInfo(global_config.spark_sql_session_timeZone)
58
- if hasattr(global_config, "spark_sql_session_timeZone")
59
- else get_localzone()
60
- )
51
+ local_tz = get_localzone()
61
52
  if t == "timestamp":
62
53
  microseconds = literal.timestamp
63
54
  else:
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
66
57
  microseconds // 1_000_000
67
58
  ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
68
59
  tz_dt = datetime.datetime.fromtimestamp(
69
- microseconds // 1_000_000, tz=tz
60
+ microseconds // 1_000_000, tz=local_tz
70
61
  ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
71
62
  if t == "timestamp_ntz":
72
63
  lit_dt = lit_dt.astimezone(datetime.timezone.utc)
73
64
  tz_dt = tz_dt.astimezone(datetime.timezone.utc)
65
+ elif not get_is_evaluating_sql():
66
+ config_tz = global_config.spark_sql_session_timeZone
67
+ config_tz = ZoneInfo(config_tz) if config_tz else local_tz
68
+ tz_dt = tz_dt.astimezone(config_tz)
69
+ lit_dt = lit_dt.astimezone(local_tz)
70
+
74
71
  return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
75
72
  case "day_time_interval":
76
73
  # TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
@@ -84,6 +81,11 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
84
81
  case "decimal":
85
82
  # literal.decimal.precision & scale are ignored, as decimal.Decimal doesn't accept them
86
83
  return decimal.Decimal(literal.decimal.value), literal.decimal.value
84
+ case "array":
85
+ array_values, element_names = zip(
86
+ *(get_literal_field_and_name(e) for e in literal.array.elements)
87
+ )
88
+ return array_values, f"ARRAY({', '.join(element_names)})"
87
89
  case "null" | None:
88
90
  return None, "NULL"
89
91
  case other:
@@ -127,10 +127,11 @@ def map_cast(
127
127
  from_type = StringType()
128
128
  if isinstance(to_type, StringType):
129
129
  to_type = StringType()
130
+
131
+ # todo - verify if that's correct SNOW-2248680
130
132
  if isinstance(from_type, TimestampType):
131
133
  from_type = TimestampType()
132
- if isinstance(to_type, TimestampType):
133
- to_type = TimestampType()
134
+
134
135
  match (from_type, to_type):
135
136
  case (_, _) if (from_type == to_type):
136
137
  result_exp = col
@@ -185,6 +186,17 @@ def map_cast(
185
186
  case (DateType(), TimestampType()):
186
187
  result_exp = snowpark_fn.to_timestamp(col)
187
188
  result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
189
+ case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
190
+ result_exp = col
191
+ case (
192
+ TimestampType(),
193
+ TimestampType() as t,
194
+ ) if t.tzinfo == TimestampTimeZone.NTZ:
195
+ zone = global_config.spark_sql_session_timeZone
196
+ result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
197
+ TimestampType(TimestampTimeZone.NTZ)
198
+ )
199
+ # todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
188
200
  case (TimestampType(), TimestampType()):
189
201
  result_exp = col
190
202
  case (_, TimestampType()) if isinstance(from_type, _NumericType):
@@ -259,8 +271,12 @@ def map_cast(
259
271
  case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
260
272
  to_type, _IntegralType
261
273
  ):
262
- result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
263
- snowpark_fn.floor(col)
274
+ result_exp = (
275
+ snowpark_fn.when(
276
+ col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
277
+ )
278
+ .when(col < 0, snowpark_fn.ceil(col))
279
+ .otherwise(snowpark_fn.floor(col))
264
280
  )
265
281
  result_exp = result_exp.cast(to_type)
266
282
  case (StringType(), _) if (isinstance(to_type, _IntegralType)):
@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
11
11
  from snowflake import snowpark
12
12
  from snowflake.snowpark import Session
13
13
  from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
14
+ from snowflake.snowpark.types import TimestampTimeZone, TimestampType
14
15
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
15
16
  from snowflake.snowpark_connect.expression import (
16
17
  map_extension,
@@ -26,7 +27,10 @@ from snowflake.snowpark_connect.expression.literal import get_literal_field_and_
26
27
  from snowflake.snowpark_connect.expression.map_cast import map_cast
27
28
  from snowflake.snowpark_connect.expression.map_sql_expression import map_sql_expr
28
29
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
29
- from snowflake.snowpark_connect.type_mapping import map_simple_types
30
+ from snowflake.snowpark_connect.type_mapping import (
31
+ map_simple_types,
32
+ proto_to_snowpark_type,
33
+ )
30
34
  from snowflake.snowpark_connect.typed_column import TypedColumn
31
35
  from snowflake.snowpark_connect.utils.context import (
32
36
  gen_sql_plan_id,
@@ -165,6 +169,12 @@ def map_expression(
165
169
  lambda: [map_simple_types(lit_type_str)],
166
170
  )
167
171
 
172
+ if lit_type_str == "array":
173
+ result_exp = snowpark_fn.lit(lit_value)
174
+ element_types = proto_to_snowpark_type(exp.literal.array.element_type)
175
+ array_type = snowpark.types.ArrayType(element_types)
176
+ return [lit_name], TypedColumn(result_exp, lambda: [array_type])
177
+
168
178
  # Decimal needs further processing to get the precision and scale properly.
169
179
  if lit_type_str == "decimal":
170
180
  # Precision and scale are optional in the proto.
@@ -190,9 +200,15 @@ def map_expression(
190
200
  return [lit_name], TypedColumn(
191
201
  snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
192
202
  )
203
+ result_exp = snowpark_fn.lit(lit_value)
204
+
205
+ if lit_type_str == "timestamp_ntz" and isinstance(
206
+ lit_value, datetime.datetime
207
+ ):
208
+ result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
193
209
 
194
210
  return [lit_name], TypedColumn(
195
- snowpark_fn.lit(lit_value), lambda: [map_simple_types(lit_type_str)]
211
+ result_exp, lambda: [map_simple_types(lit_type_str)]
196
212
  )
197
213
  case "sort_order":
198
214
  child_name, child_column = map_single_column_expression(
@@ -10,7 +10,10 @@ from snowflake.snowpark.types import BooleanType
10
10
  from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
11
11
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
12
12
  from snowflake.snowpark_connect.typed_column import TypedColumn
13
- from snowflake.snowpark_connect.utils.context import push_evaluating_sql_scope
13
+ from snowflake.snowpark_connect.utils.context import (
14
+ push_evaluating_sql_scope,
15
+ push_outer_dataframe,
16
+ )
14
17
  from snowflake.snowpark_connect.utils.telemetry import (
15
18
  SnowparkConnectNotImplementedError,
16
19
  )
@@ -52,12 +55,19 @@ def map_extension(
52
55
  return [name], typed_col
53
56
 
54
57
  case "subquery_expression":
58
+ from snowflake.snowpark_connect.dataframe_container import (
59
+ DataFrameContainer,
60
+ )
55
61
  from snowflake.snowpark_connect.expression.map_expression import (
56
62
  map_expression,
57
63
  )
58
64
  from snowflake.snowpark_connect.relation.map_relation import map_relation
59
65
 
60
- with push_evaluating_sql_scope():
66
+ current_outer_df = DataFrameContainer(
67
+ dataframe=typer.df, column_map=column_mapping
68
+ )
69
+
70
+ with push_evaluating_sql_scope(), push_outer_dataframe(current_outer_df):
61
71
  df_container = map_relation(extension.subquery_expression.input)
62
72
  df = df_container.dataframe
63
73