snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (56)
  1. snowflake/snowpark_connect/config.py +19 -14
  2. snowflake/snowpark_connect/error/error_utils.py +32 -0
  3. snowflake/snowpark_connect/error/exceptions.py +4 -0
  4. snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
  5. snowflake/snowpark_connect/expression/literal.py +9 -12
  6. snowflake/snowpark_connect/expression/map_cast.py +20 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +8 -1
  8. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  9. snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
  10. snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
  11. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
  12. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
  13. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
  14. snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
  15. snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
  16. snowflake/snowpark_connect/relation/map_extension.py +58 -24
  17. snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
  18. snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
  19. snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
  20. snowflake/snowpark_connect/relation/map_sql.py +40 -196
  21. snowflake/snowpark_connect/relation/map_udtf.py +4 -4
  22. snowflake/snowpark_connect/relation/read/map_read.py +2 -1
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
  25. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  26. snowflake/snowpark_connect/relation/read/utils.py +7 -6
  27. snowflake/snowpark_connect/relation/utils.py +170 -1
  28. snowflake/snowpark_connect/relation/write/map_write.py +306 -87
  29. snowflake/snowpark_connect/server.py +34 -5
  30. snowflake/snowpark_connect/type_mapping.py +6 -2
  31. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  32. snowflake/snowpark_connect/utils/env_utils.py +55 -0
  33. snowflake/snowpark_connect/utils/session.py +21 -4
  34. snowflake/snowpark_connect/utils/telemetry.py +213 -61
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  36. snowflake/snowpark_connect/version.py +1 -1
  37. snowflake/snowpark_decoder/__init__.py +0 -0
  38. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  39. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  40. snowflake/snowpark_decoder/dp_session.py +111 -0
  41. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  42. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
  43. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
  44. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
  45. spark/__init__.py +0 -0
  46. spark/connect/__init__.py +0 -0
  47. spark/connect/envelope_pb2.py +31 -0
  48. spark/connect/envelope_pb2.pyi +46 -0
  49. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  50. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
  51. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
  52. {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
  53. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
  54. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
  55. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0

snowflake/snowpark_connect/config.py

@@ -4,10 +4,8 @@
 
 # Proto source for reference:
 # https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
-import os
 import re
 import sys
-import time
 from collections import defaultdict
 from copy import copy, deepcopy
 from typing import Any
@@ -168,6 +166,9 @@ class GlobalConfig:
         "snowpark.connect.udf.packages": lambda session, packages: session.add_packages(
             *packages.strip("[] ").split(",")
         ),
+        "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
+            session, imports
+        ),
     }
 
     float_config_list = []
@@ -332,7 +333,7 @@ def route_config_proto(
     match op_type:
         case "set":
             logger.info("SET")
-
+            telemetry.report_config_set(config.operation.set.pairs)
             for pair in config.operation.set.pairs:
                 # Check if the value field is present, not present when invalid fields are set in conf.
                 if not pair.HasField("value"):
@@ -342,7 +343,6 @@ def route_config_proto(
                         f"Cannot set config '{pair.key}' to None"
                     )
 
-                telemetry.report_config_set(pair.key, pair.value)
                 set_config_param(
                     config.session_id, pair.key, pair.value, snowpark_session
                 )
@@ -350,14 +350,15 @@ def route_config_proto(
             return proto_base.ConfigResponse(session_id=config.session_id)
         case "unset":
             logger.info("UNSET")
+            telemetry.report_config_unset(config.operation.unset.keys)
             for key in config.operation.unset.keys:
-                telemetry.report_config_unset(key)
                 unset_config_param(config.session_id, key, snowpark_session)
 
             return proto_base.ConfigResponse(session_id=config.session_id)
         case "get":
             logger.info("GET")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.get.keys)
             for key in config.operation.get.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -367,6 +368,9 @@ def route_config_proto(
             return res
         case "get_with_default":
             logger.info("GET_WITH_DEFAULT")
+            telemetry.report_config_get(
+                [pair.key for pair in config.operation.get_with_default.pairs]
+            )
             result_pairs = [
                 proto_base.KeyValue(
                     key=pair.key,
@@ -383,6 +387,7 @@ def route_config_proto(
         case "get_option":
             logger.info("GET_OPTION")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.get_option.keys)
             for key in config.operation.get_option.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -411,6 +416,7 @@ def route_config_proto(
         case "is_modifiable":
             logger.info("IS_MODIFIABLE")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.is_modifiable.keys)
             for key in config.operation.is_modifiable.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -525,17 +531,8 @@ def set_snowflake_parameters(
                 snowpark_session.sql(
                     f"ALTER SESSION SET TIMEZONE = '{value}'"
                 ).collect()
-                set_jvm_timezone(value)
-                if hasattr(time, "tzset"):
-                    os.environ["TZ"] = value
-                    time.tzset()
             else:
                 snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
-                reset_jvm_timezone_to_system_default()
-                if hasattr(time, "tzset") and False:
-                    if "TZ" in os.environ:
-                        del os.environ["TZ"]
-                        time.tzset()
         case "spark.sql.globalTempDatabase":
             if not value:
                 value = global_config.default_static_global_config.get(key)
@@ -588,3 +585,11 @@ def auto_uppercase_non_column_identifiers() -> bool:
     return session_config[
         "snowpark.connect.sql.identifiers.auto-uppercase"
     ].lower() in ("all", "all_except_columns")
+
+
+def parse_imports(session: snowpark.Session, imports: str | None) -> None:
+    if not imports:
+        return
+
+    for udf_import in imports.strip("[] ").split(","):
+        session.add_import(udf_import)
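
Note on the new snowpark.connect.udf.imports option: like the existing snowpark.connect.udf.packages handler, the value is a bracketed, comma-separated list; parse_imports splits it and passes each entry to session.add_import. A minimal sketch of the parsing (the stage paths below are placeholders, not taken from this diff):

    # Placeholder values; only the option name and the parsing behaviour come from the diff.
    imports = "[@my_stage/helpers.py,@my_stage/models.zip]"

    for udf_import in imports.strip("[] ").split(","):
        # parse_imports() would register each entry via session.add_import(udf_import)
        print(udf_import)  # @my_stage/helpers.py, then @my_stage/models.zip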

snowflake/snowpark_connect/error/error_utils.py

@@ -28,7 +28,9 @@ from pyspark.errors.exceptions.base import (
     PySparkException,
     PythonException,
     SparkRuntimeException,
+    UnsupportedOperationException,
 )
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from snowflake.core.exceptions import NotFoundError
 
 from snowflake.connector.errors import ProgrammingError
@@ -49,7 +51,9 @@ SPARK_PYTHON_TO_JAVA_EXCEPTION = {
     ArrayIndexOutOfBoundsException: "java.lang.ArrayIndexOutOfBoundsException",
     NumberFormatException: "java.lang.NumberFormatException",
     SparkRuntimeException: "org.apache.spark.SparkRuntimeException",
+    SparkConnectGrpcException: "pyspark.errors.exceptions.connect.SparkConnectGrpcException",
     PythonException: "org.apache.spark.api.python.PythonException",
+    UnsupportedOperationException: "java.lang.UnsupportedOperationException",
 }
 
 WINDOW_FUNCTION_ANALYSIS_EXCEPTION_SQL_ERROR_CODE = {1005, 2303}
@@ -68,6 +72,13 @@ init_multi_args_exception_pattern = (
 terminate_multi_args_exception_pattern = (
     r"terminate\(\) missing \d+ required positional argument"
 )
+snowpark_connect_exception_pattern = re.compile(
+    r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
+)
+invalid_bit_pattern = re.compile(
+    r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
+    re.IGNORECASE,
+)
 
 
 def contains_udtf_select(sql_string):
@@ -100,6 +111,22 @@ def _get_converted_known_sql_or_custom_exception(
         return SparkRuntimeException(
             message="Unexpected value for start in function slice: SQL array indices start at 1."
         )
+    invalid_bit = invalid_bit_pattern.search(msg)
+    if invalid_bit:
+        return IllegalArgumentException(message=invalid_bit.group(0))
+    match = snowpark_connect_exception_pattern.search(
+        ex.message if hasattr(ex, "message") else str(ex)
+    )
+    if match:
+        class_name = match.group(1)
+        message = match.group(2)
+        exception_class = (
+            globals().get(class_name, SparkConnectGrpcException)
+            if class_name
+            else SparkConnectGrpcException
+        )
+        return exception_class(message=message)
+
     if "select with no columns" in msg and contains_udtf_select(query):
         # We try our best to detect if the SQL string contains a UDTF call and the output schema is empty.
         return PythonException(message=f"[UDTF_RETURN_SCHEMA_MISMATCH] {ex.message}")
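
The two regexes added above drive new branches in _get_converted_known_sql_or_custom_exception. A standalone check of how snowpark_connect_exception_pattern splits a backend message into an optional exception class name and a message (the sample error string is invented for illustration):

    import re

    pattern = re.compile(
        r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
    )
    sample = (
        "[snowpark-connect-exception:IllegalArgumentException] "
        "bad interval literal' is not recognized"
    )
    match = pattern.search(sample)
    assert match is not None
    assert match.group(1) == "IllegalArgumentException"  # looked up via globals().get(...)
    assert match.group(2) == "bad interval literal"      # becomes the client-side message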

@@ -131,6 +158,11 @@ def _get_converted_known_sql_or_custom_exception(
             message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the terminate method: {ex.message}"
         )
 
+    if "failed to split string, provided pattern:" in msg:
+        return IllegalArgumentException(
+            message=f"Failed to split string using provided pattern. {ex.message}"
+        )
+
     if "100357" in msg and "wrong tuple size for returned value" in msg:
         return PythonException(
             message=f"[UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not match the specified schema. {ex.message}"

snowflake/snowpark_connect/error/exceptions.py

@@ -22,3 +22,7 @@ class MissingSchema(SnowparkConnectException):
         super().__init__(
             "No default schema found in session",
         )
+
+
+class MaxRetryExceeded(SnowparkConnectException):
+    ...

snowflake/snowpark_connect/expression/hybrid_column_map.py (new file)

@@ -0,0 +1,192 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Hybrid column mapping for HAVING clause resolution.
+
+This module provides a special column mapping that can resolve expressions
+in the context of both the input DataFrame (for base columns) and the
+aggregated DataFrame (for aggregate expressions and aliases).
+"""
+
+from typing import Dict, List
+
+import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+
+from snowflake import snowpark
+from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
+from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+from snowflake.snowpark_connect.typed_column import TypedColumn
+
+
+class HybridColumnMap:
+    """
+    A column mapping that can resolve expressions in both input and aggregated contexts.
+
+    This is specifically designed for HAVING clause resolution where expressions may reference:
+    1. Base columns from the input DataFrame (to build new aggregates)
+    2. Existing aggregate expressions and their aliases
+    3. Grouping columns
+    """
+
+    def __init__(
+        self,
+        input_column_map: ColumnNameMap,
+        input_typer: ExpressionTyper,
+        aggregated_column_map: ColumnNameMap,
+        aggregated_typer: ExpressionTyper,
+        aggregate_expressions: List[expressions_proto.Expression],
+        grouping_expressions: List[expressions_proto.Expression],
+        aggregate_aliases: Dict[str, expressions_proto.Expression],
+    ) -> None:
+        self.input_column_map = input_column_map
+        self.input_typer = input_typer
+        self.aggregated_column_map = aggregated_column_map
+        self.aggregated_typer = aggregated_typer
+        self.aggregate_expressions = aggregate_expressions
+        self.grouping_expressions = grouping_expressions
+        self.aggregate_aliases = aggregate_aliases
+
+    def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
+        """Check if an expression is an aggregate function."""
+        if exp.WhichOneof("expr_type") == "unresolved_function":
+            func_name = exp.unresolved_function.function_name.lower()
+            # Common aggregate functions - expand this list as needed
+            aggregate_functions = {
+                "avg",
+                "average",
+                "sum",
+                "count",
+                "min",
+                "max",
+                "stddev",
+                "stddev_pop",
+                "stddev_samp",
+                "variance",
+                "var_pop",
+                "var_samp",
+                "collect_list",
+                "collect_set",
+                "first",
+                "last",
+                "any_value",
+                "bool_and",
+                "bool_or",
+                "corr",
+                "covar_pop",
+                "covar_samp",
+                "kurtosis",
+                "skewness",
+                "percentile_cont",
+                "percentile_disc",
+                "approx_count_distinct",
+            }
+            return func_name in aggregate_functions
+        return False
+
+    def is_grouping_column(self, column_name: str) -> bool:
+        """Check if a column name refers to a grouping column."""
+        for group_exp in self.grouping_expressions:
+            if (
+                group_exp.WhichOneof("expr_type") == "unresolved_attribute"
+                and group_exp.unresolved_attribute.unparsed_identifier == column_name
+            ):
+                return True
+        return False
+
+    def resolve_expression(
+        self, exp: expressions_proto.Expression
+    ) -> tuple[list[str], TypedColumn]:
+        """
+        Resolve an expression in the hybrid context.
+
+        Strategy:
+        1. If it's an aggregate function -> create new aggregate using input context
+        2. If it's an alias to existing aggregate -> use aggregated context
+        3. If it's a grouping column -> try aggregated context first, fall back to input context
+           (handles exclude_grouping_columns=True case)
+        4. Otherwise -> try input context first, then aggregated context
+        """
+        from snowflake.snowpark_connect.expression.map_expression import map_expression
+
+        expr_type = exp.WhichOneof("expr_type")
+
+        # Handle aggregate functions - need to evaluate against input DataFrame
+        if self.is_aggregate_function(exp):
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+        # Handle column references
+        if expr_type == "unresolved_attribute":
+            column_name = exp.unresolved_attribute.unparsed_identifier
+
+            # Check if it's an alias to an existing aggregate expression
+            if column_name in self.aggregate_aliases:
+                # Use the aggregated context to get the alias
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+            # Check if it's a grouping column
+            if self.is_grouping_column(column_name):
+                # Try aggregated context first (for cases where grouping columns are included)
+                try:
+                    return map_expression(
+                        exp, self.aggregated_column_map, self.aggregated_typer
+                    )
+                except Exception:
+                    # Fall back to input context if grouping columns were excluded
+                    # This handles the exclude_grouping_columns=True case
+                    return map_expression(exp, self.input_column_map, self.input_typer)
+
+            # Try input context first (for base columns used in new aggregates)
+            try:
+                return map_expression(exp, self.input_column_map, self.input_typer)
+            except Exception:
+                # Fall back to aggregated context
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+        # For other expression types, try aggregated context first (likely references to computed values)
+        try:
+            return map_expression(
+                exp, self.aggregated_column_map, self.aggregated_typer
+            )
+        except Exception:
+            # Fall back to input context
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+
+def create_hybrid_column_map_for_having(
+    input_df: snowpark.DataFrame,
+    input_column_map: ColumnNameMap,
+    aggregated_df: snowpark.DataFrame,
+    aggregated_column_map: ColumnNameMap,
+    aggregate_expressions: List[expressions_proto.Expression],
+    grouping_expressions: List[expressions_proto.Expression],
+    spark_columns: List[str],
+    raw_aggregations: List[tuple[str, TypedColumn]],
+) -> HybridColumnMap:
+    """
+    Create a HybridColumnMap instance for HAVING clause resolution.
+    """
+    # Create typers for both contexts
+    input_typer = ExpressionTyper(input_df)
+    aggregated_typer = ExpressionTyper(aggregated_df)
+
+    # Build alias mapping from spark column names to aggregate expressions
+    aggregate_aliases = {}
+    for i, (spark_name, _) in enumerate(raw_aggregations):
+        if i < len(aggregate_expressions):
+            aggregate_aliases[spark_name] = aggregate_expressions[i]
+
+    return HybridColumnMap(
+        input_column_map=input_column_map,
+        input_typer=input_typer,
+        aggregated_column_map=aggregated_column_map,
+        aggregated_typer=aggregated_typer,
+        aggregate_expressions=aggregate_expressions,
+        grouping_expressions=grouping_expressions,
+        aggregate_aliases=aggregate_aliases,
+    )
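
For context, the HAVING shape this new mapping handles can reference an aggregate over a base column that never appears in the aggregated output. A rough client-side example of that shape (assumes a Spark Connect session named spark; data and names are illustrative, not from this diff):

    df = spark.createDataFrame(
        [("eng", 100.0), ("eng", 200.0), ("sales", 50.0)],
        ["dept", "salary"],
    )
    df.createOrReplaceTempView("employees")

    # MAX(salary) is not part of the aggregated output, so resolving the HAVING
    # condition needs the input DataFrame (to build the new aggregate) as well as
    # the aggregated DataFrame, the two contexts resolve_expression() tries in turn.
    spark.sql(
        """
        SELECT dept, AVG(salary) AS avg_sal
        FROM employees
        GROUP BY dept
        HAVING MAX(salary) < 500
        """
    ).show()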

snowflake/snowpark_connect/expression/literal.py

@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 from tzlocal import get_localzone
 
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
             ).date()
             return date, f"DATE '{date}'"
         case "timestamp" | "timestamp_ntz" as t:
-            # Note - Clients need to ensure local_timezone is the same as spark_sql_session_timeZone config.
-            # No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
-            # the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
-            # timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
-            # to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
-            # official doc about it, but this behavior is based on my testings.
-            tz = (
-                ZoneInfo(global_config.spark_sql_session_timeZone)
-                if hasattr(global_config, "spark_sql_session_timeZone")
-                else get_localzone()
-            )
+            local_tz = get_localzone()
             if t == "timestamp":
                 microseconds = literal.timestamp
             else:
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
                 microseconds // 1_000_000
             ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
             tz_dt = datetime.datetime.fromtimestamp(
-                microseconds // 1_000_000, tz=tz
+                microseconds // 1_000_000, tz=local_tz
             ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
             if t == "timestamp_ntz":
                 lit_dt = lit_dt.astimezone(datetime.timezone.utc)
                 tz_dt = tz_dt.astimezone(datetime.timezone.utc)
+            elif not get_is_evaluating_sql():
+                config_tz = global_config.spark_sql_session_timeZone
+                config_tz = ZoneInfo(config_tz) if config_tz else local_tz
+                tz_dt = tz_dt.astimezone(config_tz)
+                lit_dt = lit_dt.astimezone(local_tz)
+
             return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
         case "day_time_interval":
             # TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
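
Net effect of the literal change: the returned datetime value stays in the local zone, while the printed literal name follows spark.sql.session.timeZone outside of SQL evaluation. A rough worked example of the naming path (the epoch value and the zone are arbitrary stand-ins):

    import datetime
    from zoneinfo import ZoneInfo

    from tzlocal import get_localzone

    microseconds = 1_700_000_000_123_456         # literal.timestamp from the proto
    local_tz = get_localzone()
    config_tz = ZoneInfo("America/Los_Angeles")  # stand-in for spark.sql.session.timeZone

    tz_dt = datetime.datetime.fromtimestamp(
        microseconds // 1_000_000, tz=local_tz
    ) + datetime.timedelta(microseconds=microseconds % 1_000_000)

    # Outside SQL evaluation the display name is rendered in the configured zone:
    tz_dt = tz_dt.astimezone(config_tz)
    print(f"TIMESTAMP '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'")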

snowflake/snowpark_connect/expression/map_cast.py

@@ -127,10 +127,11 @@ def map_cast(
         from_type = StringType()
     if isinstance(to_type, StringType):
         to_type = StringType()
+
+    # todo - verify if that's correct SNOW-2248680
     if isinstance(from_type, TimestampType):
         from_type = TimestampType()
-    if isinstance(to_type, TimestampType):
-        to_type = TimestampType()
+
     match (from_type, to_type):
         case (_, _) if (from_type == to_type):
             result_exp = col
@@ -185,6 +186,17 @@ def map_cast(
         case (DateType(), TimestampType()):
             result_exp = snowpark_fn.to_timestamp(col)
             result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
+        case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
+            result_exp = col
+        case (
+            TimestampType(),
+            TimestampType() as t,
+        ) if t.tzinfo == TimestampTimeZone.NTZ:
+            zone = global_config.spark_sql_session_timeZone
+            result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
+                TimestampType(TimestampTimeZone.NTZ)
+            )
+        # todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
         case (TimestampType(), TimestampType()):
             result_exp = col
         case (_, TimestampType()) if isinstance(from_type, _NumericType):
@@ -259,8 +271,12 @@ def map_cast(
         case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
             to_type, _IntegralType
         ):
-            result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
-                snowpark_fn.floor(col)
+            result_exp = (
+                snowpark_fn.when(
+                    col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                )
+                .when(col < 0, snowpark_fn.ceil(col))
+                .otherwise(snowpark_fn.floor(col))
             )
             result_exp = result_exp.cast(to_type)
         case (StringType(), _) if (isinstance(to_type, _IntegralType)):
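
The rewritten fractional-to-integral cast adds an explicit NaN branch ahead of the existing truncation toward zero. Restated in plain Python for clarity (Snowflake's NaN equality check is approximated here with math.isnan):

    import math

    def fractional_to_integral(value: float) -> int:
        # NaN -> 0; negatives round toward zero via CEIL, non-negatives via FLOOR.
        if math.isnan(value):
            return 0
        return math.ceil(value) if value < 0 else math.floor(value)

    assert fractional_to_integral(float("nan")) == 0
    assert fractional_to_integral(-3.7) == -3  # truncation toward zero
    assert fractional_to_integral(2.9) == 2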

snowflake/snowpark_connect/expression/map_expression.py

@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
+from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression import (
     map_extension,
@@ -190,9 +191,15 @@ def map_expression(
                 return [lit_name], TypedColumn(
                     snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
                 )
+            result_exp = snowpark_fn.lit(lit_value)
+
+            if lit_type_str == "timestamp_ntz" and isinstance(
+                lit_value, datetime.datetime
+            ):
+                result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
 
             return [lit_name], TypedColumn(
-                snowpark_fn.lit(lit_value), lambda: [map_simple_types(lit_type_str)]
+                result_exp, lambda: [map_simple_types(lit_type_str)]
             )
         case "sort_order":
             child_name, child_column = map_single_column_expression(

snowflake/snowpark_connect/expression/map_udf.py

@@ -13,10 +13,7 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.session import (
-    get_or_create_snowpark_session,
-    get_python_udxf_import_files,
-)
+from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
     gen_input_types,
@@ -28,6 +25,9 @@ from snowflake.snowpark_connect.utils.udf_helper import (
 from snowflake.snowpark_connect.utils.udf_utils import (
     ProcessCommonInlineUserDefinedFunction,
 )
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)
 
 
 def process_udf_return_type(

snowflake/snowpark_connect/expression/map_unresolved_extract_value.py

@@ -5,6 +5,7 @@
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 
 import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark.types import ArrayType, MapType, StructType, _IntegralType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
@@ -57,7 +58,8 @@ def map_unresolved_extract_value(
         extract_fn = snowpark_fn.get_ignore_case
     # Set index to a dummy value before we use it later in the ansi mode check.
     index = snowpark_fn.lit(1)
-    if _check_if_array_type(extract_typed_column, child_typed_column):
+    is_array = _check_if_array_type(extract_typed_column, child_typed_column)
+    if is_array:
         # Set all non-valid array indices to NULL.
         # This is done because both conditions of a CASE WHEN statement are executed regardless of if the condition is true or not.
         # Getting a negative index in Snowflake throws an error; thus, we convert all non-valid array indices to NULL before getting the index.
@@ -74,12 +76,37 @@ def map_unresolved_extract_value(
 
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
 
-    if spark_sql_ansi_enabled and _check_if_array_type(
-        extract_typed_column, child_typed_column
-    ):
+    if spark_sql_ansi_enabled and is_array:
         result_exp = snowpark_fn.when(
             index.isNull(),
             child_typed_column.col.getItem("[snowpark_connect::INVALID_ARRAY_INDEX]"),
         ).otherwise(result_exp)
 
-    return spark_function_name, TypedColumn(result_exp, lambda: typer.type(result_exp))
+    def _get_extracted_value_type():
+        if is_array:
+            return [child_typed_column.typ.element_type]
+        elif isinstance(child_typed_column.typ, MapType):
+            return [child_typed_column.typ.value_type]
+        elif (
+            isinstance(child_typed_column.typ, StructType)
+            and isinstance(extract_typed_column.col._expr1, Literal)
+            and isinstance(extract_typed_column.col._expr1.value, str)
+        ):
+            struct = dict(
+                {
+                    (
+                        f.name
+                        if global_config.spark_sql_caseSensitive
+                        else f.name.lower(),
+                        f.datatype,
+                    )
+                    for f in child_typed_column.typ.fields
+                }
+            )
+            key = extract_typed_column.col._expr1.value
+            key = key if global_config.spark_sql_caseSensitive else key.lower()
+
+            return [struct[key]] if key in struct else typer.type(result_exp)
+        return typer.type(result_exp)
+
+    return spark_function_name, TypedColumn(result_exp, _get_extracted_value_type)
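
The new _get_extracted_value_type derives the extraction result type from the child's type instead of always re-typing the expression. A simplified standalone restatement of that lookup using Snowpark types (no proto or typer plumbing; the helper name here is ours, not from the diff):

    from snowflake.snowpark.types import (
        ArrayType,
        IntegerType,
        MapType,
        StringType,
        StructField,
        StructType,
    )

    def extracted_value_type(container_type, key=None, case_sensitive=False):
        # Arrays yield their element type, maps their value type, structs the field
        # type matched by (optionally lower-cased) name; None means "fall back to
        # re-typing the expression", which is what the real code does via the typer.
        if isinstance(container_type, ArrayType):
            return container_type.element_type
        if isinstance(container_type, MapType):
            return container_type.value_type
        if isinstance(container_type, StructType) and isinstance(key, str):
            fields = {
                (f.name if case_sensitive else f.name.lower()): f.datatype
                for f in container_type.fields
            }
            return fields.get(key if case_sensitive else key.lower())
        return None

    assert isinstance(extracted_value_type(ArrayType(IntegerType())), IntegerType)
    struct = StructType([StructField("Name", StringType())])
    assert isinstance(extracted_value_type(struct, "name"), StringType)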