snowpark-connect 0.22.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/config.py +0 -11
- snowflake/snowpark_connect/error/error_utils.py +7 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/function_defaults.py +207 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +14 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +18 -2
- snowflake/snowpark_connect/expression/map_extension.py +12 -2
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +69 -10
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +6 -5
- snowflake/snowpark_connect/relation/map_extension.py +65 -31
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +2 -0
- snowflake/snowpark_connect/relation/map_sql.py +22 -5
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
- snowflake/snowpark_connect/relation/write/map_write.py +243 -68
- snowflake/snowpark_connect/server.py +25 -5
- snowflake/snowpark_connect/type_mapping.py +2 -2
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/spark_decoder.py +1 -1
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +44 -39
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py

@@ -4,10 +4,8 @@
 
 # Proto source for reference:
 # https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
-import os
 import re
 import sys
-import time
 from collections import defaultdict
 from copy import copy, deepcopy
 from typing import Any
@@ -533,17 +531,8 @@ def set_snowflake_parameters(
                snowpark_session.sql(
                    f"ALTER SESSION SET TIMEZONE = '{value}'"
                ).collect()
-                set_jvm_timezone(value)
-                if hasattr(time, "tzset"):
-                    os.environ["TZ"] = value
-                    time.tzset()
            else:
                snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
-                reset_jvm_timezone_to_system_default()
-                if hasattr(time, "tzset") and False:
-                    if "TZ" in os.environ:
-                        del os.environ["TZ"]
-                        time.tzset()
        case "spark.sql.globalTempDatabase":
            if not value:
                value = global_config.default_static_global_config.get(key)
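Note: with the JVM- and process-level timezone mutation removed above, setting or unsetting the session timezone now only alters the Snowflake session parameter. A minimal sketch of the client calls that drive the two branches, assuming an active Spark Connect session `spark`:

    # SET branch: issues ALTER SESSION SET TIMEZONE on the Snowpark session
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    # UNSET branch: issues ALTER SESSION UNSET TIMEZONE
    spark.conf.unset("spark.sql.session.timeZone")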
snowflake/snowpark_connect/error/error_utils.py

@@ -75,6 +75,10 @@ terminate_multi_args_exception_pattern = (
 snowpark_connect_exception_pattern = re.compile(
     r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
 )
+invalid_bit_pattern = re.compile(
+    r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
+    re.IGNORECASE,
+)
 
 
 def contains_udtf_select(sql_string):
@@ -107,6 +111,9 @@ def _get_converted_known_sql_or_custom_exception(
        return SparkRuntimeException(
            message="Unexpected value for start in function slice: SQL array indices start at 1."
        )
+    invalid_bit = invalid_bit_pattern.search(msg)
+    if invalid_bit:
+        return IllegalArgumentException(message=invalid_bit.group(0))
    match = snowpark_connect_exception_pattern.search(
        ex.message if hasattr(ex, "message") else str(ex)
    )
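For reference, a standalone sketch of how the new pattern feeds the IllegalArgumentException branch above (plain re; the message text is illustrative):

    import re

    invalid_bit_pattern = re.compile(
        r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
        re.IGNORECASE,
    )

    msg = "Invalid bit position: 33 exceeds the bit upper limit"
    match = invalid_bit_pattern.search(msg)
    assert match is not None and match.group(0) == msg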
snowflake/snowpark_connect/expression/function_defaults.py (new file)

@@ -0,0 +1,207 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+from dataclasses import dataclass
+from typing import Any
+
+import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2
+import pyspark.sql.connect.proto.types_pb2 as types_pb2
+
+
+@dataclass(frozen=True)
+class DefaultParameter:
+    """Represents a single default parameter for a function."""
+
+    name: str
+    value: Any
+
+
+@dataclass(frozen=True)
+class FunctionDefaults:
+    """Represents default parameter configuration for a function."""
+
+    total_args: int
+    defaults: list[DefaultParameter]
+
+
+# FUNCTION_DEFAULTS dictionary to hold operation name with default values.
+# This is required as non pyspark clients such as scala or sql won't send all the parameters.
+# We use this dict to inject the missing parameters before processing the unresolved function.
+FUNCTION_DEFAULTS: dict[str, FunctionDefaults] = {
+    "aes_decrypt": FunctionDefaults(
+        total_args=5,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+    "aes_encrypt": FunctionDefaults(
+        total_args=6,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter(
+                "iv", ""
+            ),  # Spark SQL default: empty string (random generated if not provided)
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+    "approx_percentile": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("accuracy", 10000)],
+    ),
+    "bround": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("scale", 0)],
+    ),
+    "first": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("ignorenulls", False)],
+    ),
+    "lag": FunctionDefaults(
+        total_args=2,
+        defaults=[
+            DefaultParameter("offset", 1),
+        ],
+    ),
+    "last": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("ignorenulls", False)],
+    ),
+    "lead": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("offset", 1), DefaultParameter("default", None)],
+    ),
+    "locate": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("pos", 1)],
+    ),
+    "months_between": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("roundOff", True)],
+    ),
+    "nth_value": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("ignoreNulls", False)],
+    ),
+    "overlay": FunctionDefaults(
+        total_args=4,
+        defaults=[DefaultParameter("len", -1)],
+    ),
+    "percentile": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("frequency", 1)],
+    ),
+    "percentile_approx": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("accuracy", 10000)],
+    ),
+    "round": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("scale", 0)],
+    ),
+    "sentences": FunctionDefaults(
+        total_args=3,
+        defaults=[
+            DefaultParameter("language", ""),
+            DefaultParameter("country", ""),
+        ],
+    ),
+    "sort_array": FunctionDefaults(
+        total_args=2,
+        defaults=[DefaultParameter("asc", True)],
+    ),
+    "split": FunctionDefaults(
+        total_args=3,
+        defaults=[DefaultParameter("limit", -1)],
+    ),
+    "str_to_map": FunctionDefaults(
+        total_args=3,
+        defaults=[
+            DefaultParameter(
+                "pairDelim", ","
+            ),  # Spark SQL default: comma for splitting pairs
+            DefaultParameter(
+                "keyValueDelim", ":"
+            ),  # Spark SQL default: colon for splitting key/value
+        ],
+    ),
+    "try_aes_decrypt": FunctionDefaults(
+        total_args=5,
+        defaults=[
+            DefaultParameter("mode", "GCM"),  # Spark SQL default: GCM
+            DefaultParameter("padding", "NONE"),  # Spark SQL default: NONE for GCM mode
+            DefaultParameter("aad", ""),  # Spark SQL default: empty string
+        ],
+    ),
+}
+
+
+def _create_literal_expression(value: Any) -> expressions_pb2.Expression:
+    """Create a literal expression for the given value."""
+    expr = expressions_pb2.Expression()
+    if isinstance(value, bool):
+        expr.literal.boolean = value
+    elif isinstance(value, int):
+        expr.literal.integer = value
+    elif isinstance(value, str):
+        expr.literal.string = value
+    elif isinstance(value, float):
+        expr.literal.double = value
+    elif value is None:
+        null_type = types_pb2.DataType()
+        null_type.null.SetInParent()
+        expr.literal.null.CopyFrom(null_type)
+    else:
+        raise ValueError(f"Unsupported literal type: {value}")
+
+    return expr
+
+
+def inject_function_defaults(
+    unresolved_function: expressions_pb2.Expression.UnresolvedFunction,
+) -> bool:
+    """
+    Inject missing default parameters into an UnresolvedFunction protobuf.
+
+    Args:
+        unresolved_function: The protobuf UnresolvedFunction to modify
+
+    Returns:
+        bool: True if any defaults were injected, False otherwise
+    """
+    function_name = unresolved_function.function_name.lower()
+
+    if function_name not in FUNCTION_DEFAULTS:
+        return False
+
+    func_config = FUNCTION_DEFAULTS[function_name]
+    current_arg_count = len(unresolved_function.arguments)
+    total_args = func_config.total_args
+    defaults = func_config.defaults
+
+    if not defaults or current_arg_count >= total_args:
+        return False
+
+    # Calculate how many defaults to append
+    missing_arg_count = total_args - current_arg_count
+
+    # Check if any required params are missing.
+    if missing_arg_count > len(defaults):
+        raise ValueError(
+            f"Function '{function_name}' is missing required arguments. "
+            f"Expected {total_args} args, got {current_arg_count}, "
+            f"but only {len(defaults)} defaults are defined."
+        )
+
+    defaults_to_append = defaults[-missing_arg_count:]
+    injected = False
+
+    # Simply append the needed default values
+    for default_param in defaults_to_append:
+        default_expr = _create_literal_expression(default_param.value)
+        unresolved_function.arguments.append(default_expr)
+        injected = True
+
+    return injected
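A minimal usage sketch of the new injection helper, assuming pyspark's Spark Connect protos are installed. The one-argument round(x) call below stands in for what a Scala or SQL client might send:

    import pyspark.sql.connect.proto.expressions_pb2 as expressions_pb2

    from snowflake.snowpark_connect.expression.function_defaults import (
        inject_function_defaults,
    )

    # "round" is registered with total_args=2 and a default scale of 0.
    fn = expressions_pb2.Expression.UnresolvedFunction()
    fn.function_name = "round"
    arg = fn.arguments.add()
    arg.unresolved_attribute.unparsed_identifier = "x"

    assert inject_function_defaults(fn)  # appends the missing default
    assert len(fn.arguments) == 2
    assert fn.arguments[1].literal.integer == 0  # scale defaults to 0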
snowflake/snowpark_connect/expression/hybrid_column_map.py (new file)

@@ -0,0 +1,192 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Hybrid column mapping for HAVING clause resolution.
+
+This module provides a special column mapping that can resolve expressions
+in the context of both the input DataFrame (for base columns) and the
+aggregated DataFrame (for aggregate expressions and aliases).
+"""
+
+from typing import Dict, List
+
+import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+
+from snowflake import snowpark
+from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
+from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+from snowflake.snowpark_connect.typed_column import TypedColumn
+
+
+class HybridColumnMap:
+    """
+    A column mapping that can resolve expressions in both input and aggregated contexts.
+
+    This is specifically designed for HAVING clause resolution where expressions may reference:
+    1. Base columns from the input DataFrame (to build new aggregates)
+    2. Existing aggregate expressions and their aliases
+    3. Grouping columns
+    """
+
+    def __init__(
+        self,
+        input_column_map: ColumnNameMap,
+        input_typer: ExpressionTyper,
+        aggregated_column_map: ColumnNameMap,
+        aggregated_typer: ExpressionTyper,
+        aggregate_expressions: List[expressions_proto.Expression],
+        grouping_expressions: List[expressions_proto.Expression],
+        aggregate_aliases: Dict[str, expressions_proto.Expression],
+    ) -> None:
+        self.input_column_map = input_column_map
+        self.input_typer = input_typer
+        self.aggregated_column_map = aggregated_column_map
+        self.aggregated_typer = aggregated_typer
+        self.aggregate_expressions = aggregate_expressions
+        self.grouping_expressions = grouping_expressions
+        self.aggregate_aliases = aggregate_aliases
+
+    def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
+        """Check if an expression is an aggregate function."""
+        if exp.WhichOneof("expr_type") == "unresolved_function":
+            func_name = exp.unresolved_function.function_name.lower()
+            # Common aggregate functions - expand this list as needed
+            aggregate_functions = {
+                "avg",
+                "average",
+                "sum",
+                "count",
+                "min",
+                "max",
+                "stddev",
+                "stddev_pop",
+                "stddev_samp",
+                "variance",
+                "var_pop",
+                "var_samp",
+                "collect_list",
+                "collect_set",
+                "first",
+                "last",
+                "any_value",
+                "bool_and",
+                "bool_or",
+                "corr",
+                "covar_pop",
+                "covar_samp",
+                "kurtosis",
+                "skewness",
+                "percentile_cont",
+                "percentile_disc",
+                "approx_count_distinct",
+            }
+            return func_name in aggregate_functions
+        return False
+
+    def is_grouping_column(self, column_name: str) -> bool:
+        """Check if a column name refers to a grouping column."""
+        for group_exp in self.grouping_expressions:
+            if (
+                group_exp.WhichOneof("expr_type") == "unresolved_attribute"
+                and group_exp.unresolved_attribute.unparsed_identifier == column_name
+            ):
+                return True
+        return False
+
+    def resolve_expression(
+        self, exp: expressions_proto.Expression
+    ) -> tuple[list[str], TypedColumn]:
+        """
+        Resolve an expression in the hybrid context.
+
+        Strategy:
+        1. If it's an aggregate function -> create new aggregate using input context
+        2. If it's an alias to existing aggregate -> use aggregated context
+        3. If it's a grouping column -> try aggregated context first, fall back to input context
+           (handles exclude_grouping_columns=True case)
+        4. Otherwise -> try input context first, then aggregated context
+        """
+        from snowflake.snowpark_connect.expression.map_expression import map_expression
+
+        expr_type = exp.WhichOneof("expr_type")
+
+        # Handle aggregate functions - need to evaluate against input DataFrame
+        if self.is_aggregate_function(exp):
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+        # Handle column references
+        if expr_type == "unresolved_attribute":
+            column_name = exp.unresolved_attribute.unparsed_identifier
+
+            # Check if it's an alias to an existing aggregate expression
+            if column_name in self.aggregate_aliases:
+                # Use the aggregated context to get the alias
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+            # Check if it's a grouping column
+            if self.is_grouping_column(column_name):
+                # Try aggregated context first (for cases where grouping columns are included)
+                try:
+                    return map_expression(
+                        exp, self.aggregated_column_map, self.aggregated_typer
+                    )
+                except Exception:
+                    # Fall back to input context if grouping columns were excluded
+                    # This handles the exclude_grouping_columns=True case
+                    return map_expression(exp, self.input_column_map, self.input_typer)
+
+            # Try input context first (for base columns used in new aggregates)
+            try:
+                return map_expression(exp, self.input_column_map, self.input_typer)
+            except Exception:
+                # Fall back to aggregated context
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+        # For other expression types, try aggregated context first (likely references to computed values)
+        try:
+            return map_expression(
+                exp, self.aggregated_column_map, self.aggregated_typer
+            )
+        except Exception:
+            # Fall back to input context
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+
+def create_hybrid_column_map_for_having(
+    input_df: snowpark.DataFrame,
+    input_column_map: ColumnNameMap,
+    aggregated_df: snowpark.DataFrame,
+    aggregated_column_map: ColumnNameMap,
+    aggregate_expressions: List[expressions_proto.Expression],
+    grouping_expressions: List[expressions_proto.Expression],
+    spark_columns: List[str],
+    raw_aggregations: List[tuple[str, TypedColumn]],
+) -> HybridColumnMap:
+    """
+    Create a HybridColumnMap instance for HAVING clause resolution.
+    """
+    # Create typers for both contexts
+    input_typer = ExpressionTyper(input_df)
+    aggregated_typer = ExpressionTyper(aggregated_df)
+
+    # Build alias mapping from spark column names to aggregate expressions
+    aggregate_aliases = {}
+    for i, (spark_name, _) in enumerate(raw_aggregations):
+        if i < len(aggregate_expressions):
+            aggregate_aliases[spark_name] = aggregate_expressions[i]
+
+    return HybridColumnMap(
+        input_column_map=input_column_map,
+        input_typer=input_typer,
+        aggregated_column_map=aggregated_column_map,
+        aggregated_typer=aggregated_typer,
+        aggregate_expressions=aggregate_expressions,
+        grouping_expressions=grouping_expressions,
+        aggregate_aliases=aggregate_aliases,
+    )
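The kind of HAVING query this hybrid resolution targets, as a sketch (session and table names are illustrative): the predicate mixes an alias of an existing aggregate, a grouping column, and a brand-new aggregate over a base column:

    spark.sql("""
        SELECT dept, avg(salary) AS avg_sal
        FROM employees
        GROUP BY dept
        HAVING avg_sal > 50000 AND max(salary) < 200000
    """)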
snowflake/snowpark_connect/expression/literal.py

@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 from tzlocal import get_localzone
 
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
            ).date()
            return date, f"DATE '{date}'"
        case "timestamp" | "timestamp_ntz" as t:
-
-            # No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
-            # the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
-            # timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
-            # to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
-            # official doc about it, but this behavior is based on my testings.
-            tz = (
-                ZoneInfo(global_config.spark_sql_session_timeZone)
-                if hasattr(global_config, "spark_sql_session_timeZone")
-                else get_localzone()
-            )
+            local_tz = get_localzone()
            if t == "timestamp":
                microseconds = literal.timestamp
            else:
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
                microseconds // 1_000_000
            ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
            tz_dt = datetime.datetime.fromtimestamp(
-                microseconds // 1_000_000, tz=tz
+                microseconds // 1_000_000, tz=local_tz
            ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
            if t == "timestamp_ntz":
                lit_dt = lit_dt.astimezone(datetime.timezone.utc)
                tz_dt = tz_dt.astimezone(datetime.timezone.utc)
+            elif not get_is_evaluating_sql():
+                config_tz = global_config.spark_sql_session_timeZone
+                config_tz = ZoneInfo(config_tz) if config_tz else local_tz
+                tz_dt = tz_dt.astimezone(config_tz)
+                lit_dt = lit_dt.astimezone(local_tz)
+
            return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
        case "day_time_interval":
            # TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
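A standalone sketch of the renaming behavior above: the returned value stays in the server's local timezone, while the literal's rendered column name follows the configured session timezone ("America/New_York" stands in for global_config.spark_sql_session_timeZone):

    from datetime import datetime
    from zoneinfo import ZoneInfo

    from tzlocal import get_localzone

    micros = 1_704_067_200_000_000  # 2024-01-01 00:00:00 UTC
    local_tz = get_localzone()

    lit_dt = datetime.fromtimestamp(micros // 1_000_000, tz=local_tz)
    tz_dt = lit_dt.astimezone(ZoneInfo("America/New_York"))
    print(f"TIMESTAMP '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'")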
@@ -84,6 +81,11 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
        case "decimal":
            # literal.decimal.precision & scale are ignored, as decimal.Decimal doesn't accept them
            return decimal.Decimal(literal.decimal.value), literal.decimal.value
+        case "array":
+            array_values, element_names = zip(
+                *(get_literal_field_and_name(e) for e in literal.array.elements)
+            )
+            return array_values, f"ARRAY({', '.join(element_names)})"
        case "null" | None:
            return None, "NULL"
        case other:
snowflake/snowpark_connect/expression/map_cast.py

@@ -127,10 +127,11 @@ def map_cast(
        from_type = StringType()
    if isinstance(to_type, StringType):
        to_type = StringType()
+
+    # todo - verify if that's correct SNOW-2248680
    if isinstance(from_type, TimestampType):
        from_type = TimestampType()
-
-        to_type = TimestampType()
+
    match (from_type, to_type):
        case (_, _) if (from_type == to_type):
            result_exp = col
@@ -185,6 +186,17 @@ def map_cast(
        case (DateType(), TimestampType()):
            result_exp = snowpark_fn.to_timestamp(col)
            result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
+        case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
+            result_exp = col
+        case (
+            TimestampType(),
+            TimestampType() as t,
+        ) if t.tzinfo == TimestampTimeZone.NTZ:
+            zone = global_config.spark_sql_session_timeZone
+            result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
+                TimestampType(TimestampTimeZone.NTZ)
+            )
+        # todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
        case (TimestampType(), TimestampType()):
            result_exp = col
        case (_, TimestampType()) if isinstance(from_type, _NumericType):
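Client-side, the new TZ-to-NTZ branch is exercised by casts like the following sketch (assumes an active Spark Connect session with a TIMESTAMP column "ts"; the "timestamp_ntz" cast string requires Spark 3.4+):

    from pyspark.sql import functions as F

    df.select(F.col("ts").cast("timestamp_ntz"))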
@@ -259,8 +271,12 @@ def map_cast(
        case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
            to_type, _IntegralType
        ):
-            result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
-                snowpark_fn.floor(col)
+            result_exp = (
+                snowpark_fn.when(
+                    col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                )
+                .when(col < 0, snowpark_fn.ceil(col))
+                .otherwise(snowpark_fn.floor(col))
            )
            result_exp = result_exp.cast(to_type)
        case (StringType(), _) if (isinstance(to_type, _IntegralType)):
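The rewritten fractional-to-integral cast encodes Spark's truncate-toward-zero semantics, with NaN mapped to 0. A pure-Python reference of the same when/otherwise chain:

    import math

    def fractional_to_int(x: float) -> int:
        # NaN casts to 0; other values truncate toward zero
        # (ceil for negatives, floor otherwise).
        if math.isnan(x):
            return 0
        return math.ceil(x) if x < 0 else math.floor(x)

    assert fractional_to_int(2.9) == 2
    assert fractional_to_int(-2.9) == -2
    assert fractional_to_int(float("nan")) == 0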
snowflake/snowpark_connect/expression/map_expression.py

@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
+from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression import (
     map_extension,
@@ -26,7 +27,10 @@ from snowflake.snowpark_connect.expression.literal import get_literal_field_and_
 from snowflake.snowpark_connect.expression.map_cast import map_cast
 from snowflake.snowpark_connect.expression.map_sql_expression import map_sql_expr
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
-from snowflake.snowpark_connect.type_mapping import map_simple_types
+from snowflake.snowpark_connect.type_mapping import (
+    map_simple_types,
+    proto_to_snowpark_type,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     gen_sql_plan_id,
@@ -165,6 +169,12 @@ def map_expression(
                lambda: [map_simple_types(lit_type_str)],
            )
 
+            if lit_type_str == "array":
+                result_exp = snowpark_fn.lit(lit_value)
+                element_types = proto_to_snowpark_type(exp.literal.array.element_type)
+                array_type = snowpark.types.ArrayType(element_types)
+                return [lit_name], TypedColumn(result_exp, lambda: [array_type])
+
            # Decimal needs further processing to get the precision and scale properly.
            if lit_type_str == "decimal":
                # Precision and scale are optional in the proto.
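A client call that produces the literal-array plan handled by the new branch above, as a sketch (assumes an active session; list support in lit requires Spark 3.4+):

    from pyspark.sql import functions as F

    df.select(F.lit([1, 2, 3]).alias("arr"))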
@@ -190,9 +200,15 @@ def map_expression(
                return [lit_name], TypedColumn(
                    snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
                )
+            result_exp = snowpark_fn.lit(lit_value)
+
+            if lit_type_str == "timestamp_ntz" and isinstance(
+                lit_value, datetime.datetime
+            ):
+                result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
 
            return [lit_name], TypedColumn(
-                snowpark_fn.lit(lit_value), lambda: [map_simple_types(lit_type_str)]
+                result_exp, lambda: [map_simple_types(lit_type_str)]
            )
        case "sort_order":
            child_name, child_column = map_single_column_expression(
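One way a client produces a TIMESTAMP_NTZ literal for the new cast above, as a sketch (spark.sql.timestampType is the Spark 3.4+ knob; session assumed):

    import datetime

    from pyspark.sql import functions as F

    spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")
    df.select(F.lit(datetime.datetime(2024, 1, 1, 12, 0)))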
snowflake/snowpark_connect/expression/map_extension.py

@@ -10,7 +10,10 @@ from snowflake.snowpark.types import BooleanType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.context import push_outer_dataframe
+from snowflake.snowpark_connect.utils.context import (
+    push_evaluating_sql_scope,
+    push_outer_dataframe,
+)
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -52,12 +55,19 @@ def map_extension(
            return [name], typed_col
 
        case "subquery_expression":
+            from snowflake.snowpark_connect.dataframe_container import (
+                DataFrameContainer,
+            )
            from snowflake.snowpark_connect.expression.map_expression import (
                map_expression,
            )
            from snowflake.snowpark_connect.relation.map_relation import map_relation
 
-
+            current_outer_df = DataFrameContainer(
+                dataframe=typer.df, column_map=column_mapping
+            )
+
+            with push_evaluating_sql_scope(), push_outer_dataframe(current_outer_df):
                df_container = map_relation(extension.subquery_expression.input)
                df = df_container.dataframe
 
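The shape of query the subquery_expression handling targets, as a sketch (names illustrative): a correlated scalar subquery whose outer DataFrame must be pushed as context while the inner relation is mapped:

    spark.sql("""
        SELECT name, salary
        FROM employees e
        WHERE salary > (SELECT avg(salary) FROM employees WHERE dept = e.dept)
    """)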