snowpark-connect 0.22.1__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- snowflake/snowpark_connect/config.py +0 -11
- snowflake/snowpark_connect/error/error_utils.py +7 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +9 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +8 -1
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +66 -6
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +57 -5
- snowflake/snowpark_connect/relation/map_column_ops.py +38 -6
- snowflake/snowpark_connect/relation/map_extension.py +58 -24
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
- snowflake/snowpark_connect/relation/map_sql.py +22 -5
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +9 -0
- snowflake/snowpark_connect/relation/read/utils.py +7 -6
- snowflake/snowpark_connect/relation/utils.py +170 -1
- snowflake/snowpark_connect/relation/write/map_write.py +243 -68
- snowflake/snowpark_connect/server.py +25 -5
- snowflake/snowpark_connect/type_mapping.py +2 -2
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/spark_decoder.py +1 -1
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +40 -40
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
- {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.22.1.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.22.1.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +0 -0
|
@@ -4,10 +4,8 @@
|
|
|
4
4
|
|
|
5
5
|
# Proto source for reference:
|
|
6
6
|
# https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
|
|
7
|
-
import os
|
|
8
7
|
import re
|
|
9
8
|
import sys
|
|
10
|
-
import time
|
|
11
9
|
from collections import defaultdict
|
|
12
10
|
from copy import copy, deepcopy
|
|
13
11
|
from typing import Any
|
|
@@ -533,17 +531,8 @@ def set_snowflake_parameters(
|
|
|
533
531
|
snowpark_session.sql(
|
|
534
532
|
f"ALTER SESSION SET TIMEZONE = '{value}'"
|
|
535
533
|
).collect()
|
|
536
|
-
set_jvm_timezone(value)
|
|
537
|
-
if hasattr(time, "tzset"):
|
|
538
|
-
os.environ["TZ"] = value
|
|
539
|
-
time.tzset()
|
|
540
534
|
else:
|
|
541
535
|
snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
|
|
542
|
-
reset_jvm_timezone_to_system_default()
|
|
543
|
-
if hasattr(time, "tzset") and False:
|
|
544
|
-
if "TZ" in os.environ:
|
|
545
|
-
del os.environ["TZ"]
|
|
546
|
-
time.tzset()
|
|
547
536
|
case "spark.sql.globalTempDatabase":
|
|
548
537
|
if not value:
|
|
549
538
|
value = global_config.default_static_global_config.get(key)
|
|
@@ -75,6 +75,10 @@ terminate_multi_args_exception_pattern = (
|
|
|
75
75
|
snowpark_connect_exception_pattern = re.compile(
|
|
76
76
|
r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
|
|
77
77
|
)
|
|
78
|
+
invalid_bit_pattern = re.compile(
|
|
79
|
+
r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
|
|
80
|
+
re.IGNORECASE,
|
|
81
|
+
)
|
|
78
82
|
|
|
79
83
|
|
|
80
84
|
def contains_udtf_select(sql_string):
|
|
@@ -107,6 +111,9 @@ def _get_converted_known_sql_or_custom_exception(
|
|
|
107
111
|
return SparkRuntimeException(
|
|
108
112
|
message="Unexpected value for start in function slice: SQL array indices start at 1."
|
|
109
113
|
)
|
|
114
|
+
invalid_bit = invalid_bit_pattern.search(msg)
|
|
115
|
+
if invalid_bit:
|
|
116
|
+
return IllegalArgumentException(message=invalid_bit.group(0))
|
|
110
117
|
match = snowpark_connect_exception_pattern.search(
|
|
111
118
|
ex.message if hasattr(ex, "message") else str(ex)
|
|
112
119
|
)
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Hybrid column mapping for HAVING clause resolution.
|
|
7
|
+
|
|
8
|
+
This module provides a special column mapping that can resolve expressions
|
|
9
|
+
in the context of both the input DataFrame (for base columns) and the
|
|
10
|
+
aggregated DataFrame (for aggregate expressions and aliases).
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Dict, List
|
|
14
|
+
|
|
15
|
+
import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
|
|
16
|
+
|
|
17
|
+
from snowflake import snowpark
|
|
18
|
+
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
|
|
19
|
+
from snowflake.snowpark_connect.expression.typer import ExpressionTyper
|
|
20
|
+
from snowflake.snowpark_connect.typed_column import TypedColumn
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class HybridColumnMap:
    """
    A column mapping that can resolve expressions in both input and aggregated contexts.

    This is specifically designed for HAVING clause resolution where expressions may reference:
    1. Base columns from the input DataFrame (to build new aggregates)
    2. Existing aggregate expressions and their aliases
    3. Grouping columns
    """

    # Lower-cased names of Spark aggregate functions recognised by
    # ``is_aggregate_function``. Built once at class definition instead of on
    # every call. Boolean aggregates (bool_and/bool_or/every/some/any) matter
    # most here because they can appear directly as a HAVING predicate.
    _AGGREGATE_FUNCTIONS = frozenset(
        {
            # Basic aggregates.
            "avg",
            "average",
            "mean",
            "sum",
            "try_avg",
            "try_sum",
            "count",
            "count_if",
            "approx_count_distinct",
            "min",
            "max",
            "min_by",
            "max_by",
            "product",
            # Statistical aggregates.
            "std",
            "stddev",
            "stddev_pop",
            "stddev_samp",
            "variance",
            "var_pop",
            "var_samp",
            "corr",
            "covar_pop",
            "covar_samp",
            "kurtosis",
            "skewness",
            "median",
            "mode",
            "percentile",
            "percentile_approx",
            "approx_percentile",
            "percentile_cont",
            "percentile_disc",
            "regr_avgx",
            "regr_avgy",
            "regr_count",
            "regr_intercept",
            "regr_r2",
            "regr_slope",
            "regr_sxx",
            "regr_sxy",
            "regr_syy",
            # Collection / positional aggregates.
            "collect_list",
            "collect_set",
            "array_agg",
            "first",
            "last",
            "first_value",
            "last_value",
            "any_value",
            # Boolean and bitwise aggregates.
            "bool_and",
            "bool_or",
            "every",
            "some",
            "any",
            "bit_and",
            "bit_or",
            "bit_xor",
        }
    )

    def __init__(
        self,
        input_column_map: ColumnNameMap,
        input_typer: ExpressionTyper,
        aggregated_column_map: ColumnNameMap,
        aggregated_typer: ExpressionTyper,
        aggregate_expressions: List[expressions_proto.Expression],
        grouping_expressions: List[expressions_proto.Expression],
        aggregate_aliases: Dict[str, expressions_proto.Expression],
    ) -> None:
        """
        Store both resolution contexts.

        Args:
            input_column_map: Column map of the pre-aggregation DataFrame.
            input_typer: Typer bound to the pre-aggregation DataFrame.
            aggregated_column_map: Column map of the aggregated DataFrame.
            aggregated_typer: Typer bound to the aggregated DataFrame.
            aggregate_expressions: The aggregate expressions of the query.
            grouping_expressions: The GROUP BY expressions of the query.
            aggregate_aliases: Output column name -> aggregate expression.
        """
        self.input_column_map = input_column_map
        self.input_typer = input_typer
        self.aggregated_column_map = aggregated_column_map
        self.aggregated_typer = aggregated_typer
        self.aggregate_expressions = aggregate_expressions
        self.grouping_expressions = grouping_expressions
        self.aggregate_aliases = aggregate_aliases

    def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
        """
        Check if an expression is an aggregate function.

        Only the top-level node is inspected: it must be an
        ``unresolved_function`` whose lower-cased name is a known aggregate.
        An aggregate nested inside another expression (e.g. the ``sum`` in
        ``sum(x) > 5``) is not detected by this check.
        """
        if exp.WhichOneof("expr_type") != "unresolved_function":
            return False
        func_name = exp.unresolved_function.function_name.lower()
        return func_name in self._AGGREGATE_FUNCTIONS

    def is_grouping_column(self, column_name: str) -> bool:
        """
        Check if a column name refers to a grouping column.

        Only grouping expressions that are plain ``unresolved_attribute``
        nodes are considered. The identifier comparison is an exact,
        case-sensitive string match.
        """
        return any(
            group_exp.WhichOneof("expr_type") == "unresolved_attribute"
            and group_exp.unresolved_attribute.unparsed_identifier == column_name
            for group_exp in self.grouping_expressions
        )

    def resolve_expression(
        self, exp: expressions_proto.Expression
    ) -> tuple[list[str], TypedColumn]:
        """
        Resolve an expression in the hybrid context.

        Strategy:
        1. If it's an aggregate function -> create new aggregate using input context
        2. If it's a column reference to an existing aggregate alias -> use aggregated context
        3. If it's a grouping column -> try aggregated context first, fall back to input context
           (handles exclude_grouping_columns=True case)
        4. If it's any other column reference -> try input context first, then aggregated context
        5. Any other expression type -> try aggregated context first, then input context

        Returns:
            The ``(column names, TypedColumn)`` pair produced by ``map_expression``.

        Raises:
            Whatever ``map_expression`` raises in the last context attempted.
            When a fallback is attempted and also fails, the first failure is
            preserved as the raised exception's ``__context__``.
        """
        # 1. Aggregates must be evaluated against the un-aggregated input.
        if self.is_aggregate_function(exp):
            return self._map_in_input_context(exp)

        if exp.WhichOneof("expr_type") == "unresolved_attribute":
            column_name = exp.unresolved_attribute.unparsed_identifier

            # 2. Alias of an existing aggregate: it only exists after aggregation.
            if column_name in self.aggregate_aliases:
                return self._map_in_aggregated_context(exp)

            # 3. Grouping column: present in the aggregated output unless
            #    grouping columns were excluded from it.
            if self.is_grouping_column(column_name):
                return self._map_with_fallback(exp, prefer_aggregated=True)

            # 4. Plain base column (e.g. an argument to build a new aggregate).
            return self._map_with_fallback(exp, prefer_aggregated=False)

        # 5. Other expression types are most likely references to computed values.
        return self._map_with_fallback(exp, prefer_aggregated=True)

    def _map_in_input_context(
        self, exp: expressions_proto.Expression
    ) -> tuple[list[str], TypedColumn]:
        """Map ``exp`` against the pre-aggregation (input) DataFrame."""
        # Function-scope import kept from the original; presumably it avoids a
        # circular import with map_expression (TODO confirm).
        from snowflake.snowpark_connect.expression.map_expression import map_expression

        return map_expression(exp, self.input_column_map, self.input_typer)

    def _map_in_aggregated_context(
        self, exp: expressions_proto.Expression
    ) -> tuple[list[str], TypedColumn]:
        """Map ``exp`` against the aggregated DataFrame."""
        from snowflake.snowpark_connect.expression.map_expression import map_expression

        return map_expression(exp, self.aggregated_column_map, self.aggregated_typer)

    def _map_with_fallback(
        self, exp: expressions_proto.Expression, prefer_aggregated: bool
    ) -> tuple[list[str], TypedColumn]:
        """Map ``exp`` in the preferred context, retrying in the other one on failure."""
        if prefer_aggregated:
            primary, fallback = (
                self._map_in_aggregated_context,
                self._map_in_input_context,
            )
        else:
            primary, fallback = (
                self._map_in_input_context,
                self._map_in_aggregated_context,
            )
        try:
            return primary(exp)
        except Exception:
            # Deliberately broad: any failure in the preferred context
            # (typically an unresolvable column) triggers the alternative.
            # If that also fails, its exception propagates with this one
            # chained as __context__.
            return fallback(exp)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def create_hybrid_column_map_for_having(
    input_df: snowpark.DataFrame,
    input_column_map: ColumnNameMap,
    aggregated_df: snowpark.DataFrame,
    aggregated_column_map: ColumnNameMap,
    aggregate_expressions: List[expressions_proto.Expression],
    grouping_expressions: List[expressions_proto.Expression],
    spark_columns: List[str],
    raw_aggregations: List[tuple[str, TypedColumn]],
) -> HybridColumnMap:
    """
    Build the HybridColumnMap used to resolve a HAVING condition.

    An ExpressionTyper is created for the input DataFrame and another for the
    aggregated DataFrame. The output name of each entry in ``raw_aggregations``
    is paired positionally with the entry at the same index of
    ``aggregate_expressions``. Pairing stops at the shorter of the two lists,
    and a repeated name keeps the last expression paired with it.

    ``spark_columns`` is accepted but not used here.
    """
    typer_for_input = ExpressionTyper(input_df)
    typer_for_aggregated = ExpressionTyper(aggregated_df)

    # Positional pairing assumes raw_aggregations[i] was produced from
    # aggregate_expressions[i] -- TODO confirm against the caller.
    output_names = (name for name, _unused_column in raw_aggregations)
    alias_to_expression = dict(zip(output_names, aggregate_expressions))

    return HybridColumnMap(
        input_column_map=input_column_map,
        input_typer=typer_for_input,
        aggregated_column_map=aggregated_column_map,
        aggregated_typer=typer_for_aggregated,
        aggregate_expressions=aggregate_expressions,
        grouping_expressions=grouping_expressions,
        aggregate_aliases=alias_to_expression,
    )
|
|
@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
|
|
|
10
10
|
from tzlocal import get_localzone
|
|
11
11
|
|
|
12
12
|
from snowflake.snowpark_connect.config import global_config
|
|
13
|
+
from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
|
|
13
14
|
from snowflake.snowpark_connect.utils.telemetry import (
|
|
14
15
|
SnowparkConnectNotImplementedError,
|
|
15
16
|
)
|
|
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
|
|
|
47
48
|
).date()
|
|
48
49
|
return date, f"DATE '{date}'"
|
|
49
50
|
case "timestamp" | "timestamp_ntz" as t:
|
|
50
|
-
|
|
51
|
-
# No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
|
|
52
|
-
# the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
|
|
53
|
-
# timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
|
|
54
|
-
# to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
|
|
55
|
-
# official doc about it, but this behavior is based on my testings.
|
|
56
|
-
tz = (
|
|
57
|
-
ZoneInfo(global_config.spark_sql_session_timeZone)
|
|
58
|
-
if hasattr(global_config, "spark_sql_session_timeZone")
|
|
59
|
-
else get_localzone()
|
|
60
|
-
)
|
|
51
|
+
local_tz = get_localzone()
|
|
61
52
|
if t == "timestamp":
|
|
62
53
|
microseconds = literal.timestamp
|
|
63
54
|
else:
|
|
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
|
|
|
66
57
|
microseconds // 1_000_000
|
|
67
58
|
) + datetime.timedelta(microseconds=microseconds % 1_000_000)
|
|
68
59
|
tz_dt = datetime.datetime.fromtimestamp(
|
|
69
|
-
microseconds // 1_000_000, tz=
|
|
60
|
+
microseconds // 1_000_000, tz=local_tz
|
|
70
61
|
) + datetime.timedelta(microseconds=microseconds % 1_000_000)
|
|
71
62
|
if t == "timestamp_ntz":
|
|
72
63
|
lit_dt = lit_dt.astimezone(datetime.timezone.utc)
|
|
73
64
|
tz_dt = tz_dt.astimezone(datetime.timezone.utc)
|
|
65
|
+
elif not get_is_evaluating_sql():
|
|
66
|
+
config_tz = global_config.spark_sql_session_timeZone
|
|
67
|
+
config_tz = ZoneInfo(config_tz) if config_tz else local_tz
|
|
68
|
+
tz_dt = tz_dt.astimezone(config_tz)
|
|
69
|
+
lit_dt = lit_dt.astimezone(local_tz)
|
|
70
|
+
|
|
74
71
|
return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
|
|
75
72
|
case "day_time_interval":
|
|
76
73
|
# TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
|
|
@@ -127,10 +127,11 @@ def map_cast(
|
|
|
127
127
|
from_type = StringType()
|
|
128
128
|
if isinstance(to_type, StringType):
|
|
129
129
|
to_type = StringType()
|
|
130
|
+
|
|
131
|
+
# todo - verify if that's correct SNOW-2248680
|
|
130
132
|
if isinstance(from_type, TimestampType):
|
|
131
133
|
from_type = TimestampType()
|
|
132
|
-
|
|
133
|
-
to_type = TimestampType()
|
|
134
|
+
|
|
134
135
|
match (from_type, to_type):
|
|
135
136
|
case (_, _) if (from_type == to_type):
|
|
136
137
|
result_exp = col
|
|
@@ -185,6 +186,17 @@ def map_cast(
|
|
|
185
186
|
case (DateType(), TimestampType()):
|
|
186
187
|
result_exp = snowpark_fn.to_timestamp(col)
|
|
187
188
|
result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
|
|
189
|
+
case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
|
|
190
|
+
result_exp = col
|
|
191
|
+
case (
|
|
192
|
+
TimestampType(),
|
|
193
|
+
TimestampType() as t,
|
|
194
|
+
) if t.tzinfo == TimestampTimeZone.NTZ:
|
|
195
|
+
zone = global_config.spark_sql_session_timeZone
|
|
196
|
+
result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
|
|
197
|
+
TimestampType(TimestampTimeZone.NTZ)
|
|
198
|
+
)
|
|
199
|
+
# todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
|
|
188
200
|
case (TimestampType(), TimestampType()):
|
|
189
201
|
result_exp = col
|
|
190
202
|
case (_, TimestampType()) if isinstance(from_type, _NumericType):
|
|
@@ -259,8 +271,12 @@ def map_cast(
|
|
|
259
271
|
case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
|
|
260
272
|
to_type, _IntegralType
|
|
261
273
|
):
|
|
262
|
-
result_exp =
|
|
263
|
-
snowpark_fn.
|
|
274
|
+
result_exp = (
|
|
275
|
+
snowpark_fn.when(
|
|
276
|
+
col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
|
|
277
|
+
)
|
|
278
|
+
.when(col < 0, snowpark_fn.ceil(col))
|
|
279
|
+
.otherwise(snowpark_fn.floor(col))
|
|
264
280
|
)
|
|
265
281
|
result_exp = result_exp.cast(to_type)
|
|
266
282
|
case (StringType(), _) if (isinstance(to_type, _IntegralType)):
|
|
@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
|
|
|
11
11
|
from snowflake import snowpark
|
|
12
12
|
from snowflake.snowpark import Session
|
|
13
13
|
from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
|
|
14
|
+
from snowflake.snowpark.types import TimestampTimeZone, TimestampType
|
|
14
15
|
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
|
|
15
16
|
from snowflake.snowpark_connect.expression import (
|
|
16
17
|
map_extension,
|
|
@@ -190,9 +191,15 @@ def map_expression(
|
|
|
190
191
|
return [lit_name], TypedColumn(
|
|
191
192
|
snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
|
|
192
193
|
)
|
|
194
|
+
result_exp = snowpark_fn.lit(lit_value)
|
|
195
|
+
|
|
196
|
+
if lit_type_str == "timestamp_ntz" and isinstance(
|
|
197
|
+
lit_value, datetime.datetime
|
|
198
|
+
):
|
|
199
|
+
result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
|
|
193
200
|
|
|
194
201
|
return [lit_name], TypedColumn(
|
|
195
|
-
|
|
202
|
+
result_exp, lambda: [map_simple_types(lit_type_str)]
|
|
196
203
|
)
|
|
197
204
|
case "sort_order":
|
|
198
205
|
child_name, child_column = map_single_column_expression(
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
|
|
6
6
|
|
|
7
7
|
import snowflake.snowpark.functions as snowpark_fn
|
|
8
|
+
from snowflake.snowpark._internal.analyzer.expression import Literal
|
|
8
9
|
from snowflake.snowpark.types import ArrayType, MapType, StructType, _IntegralType
|
|
9
10
|
from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
|
|
10
11
|
from snowflake.snowpark_connect.config import global_config
|
|
@@ -57,7 +58,8 @@ def map_unresolved_extract_value(
|
|
|
57
58
|
extract_fn = snowpark_fn.get_ignore_case
|
|
58
59
|
# Set index to a dummy value before we use it later in the ansi mode check.
|
|
59
60
|
index = snowpark_fn.lit(1)
|
|
60
|
-
|
|
61
|
+
is_array = _check_if_array_type(extract_typed_column, child_typed_column)
|
|
62
|
+
if is_array:
|
|
61
63
|
# Set all non-valid array indices to NULL.
|
|
62
64
|
# This is done because both conditions of a CASE WHEN statement are executed regardless of if the condition is true or not.
|
|
63
65
|
# Getting a negative index in Snowflake throws an error; thus, we convert all non-valid array indices to NULL before getting the index.
|
|
@@ -74,12 +76,37 @@ def map_unresolved_extract_value(
|
|
|
74
76
|
|
|
75
77
|
spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
|
|
76
78
|
|
|
77
|
-
if spark_sql_ansi_enabled and
|
|
78
|
-
extract_typed_column, child_typed_column
|
|
79
|
-
):
|
|
79
|
+
if spark_sql_ansi_enabled and is_array:
|
|
80
80
|
result_exp = snowpark_fn.when(
|
|
81
81
|
index.isNull(),
|
|
82
82
|
child_typed_column.col.getItem("[snowpark_connect::INVALID_ARRAY_INDEX]"),
|
|
83
83
|
).otherwise(result_exp)
|
|
84
84
|
|
|
85
|
-
|
|
85
|
+
def _get_extracted_value_type():
|
|
86
|
+
if is_array:
|
|
87
|
+
return [child_typed_column.typ.element_type]
|
|
88
|
+
elif isinstance(child_typed_column.typ, MapType):
|
|
89
|
+
return [child_typed_column.typ.value_type]
|
|
90
|
+
elif (
|
|
91
|
+
isinstance(child_typed_column.typ, StructType)
|
|
92
|
+
and isinstance(extract_typed_column.col._expr1, Literal)
|
|
93
|
+
and isinstance(extract_typed_column.col._expr1.value, str)
|
|
94
|
+
):
|
|
95
|
+
struct = dict(
|
|
96
|
+
{
|
|
97
|
+
(
|
|
98
|
+
f.name
|
|
99
|
+
if global_config.spark_sql_caseSensitive
|
|
100
|
+
else f.name.lower(),
|
|
101
|
+
f.datatype,
|
|
102
|
+
)
|
|
103
|
+
for f in child_typed_column.typ.fields
|
|
104
|
+
}
|
|
105
|
+
)
|
|
106
|
+
key = extract_typed_column.col._expr1.value
|
|
107
|
+
key = key if global_config.spark_sql_caseSensitive else key.lower()
|
|
108
|
+
|
|
109
|
+
return [struct[key]] if key in struct else typer.type(result_exp)
|
|
110
|
+
return typer.type(result_exp)
|
|
111
|
+
|
|
112
|
+
return spark_function_name, TypedColumn(result_exp, _get_extracted_value_type)
|
|
@@ -1107,7 +1107,7 @@ def map_unresolved_function(
|
|
|
1107
1107
|
result_exp = TypedColumn(
|
|
1108
1108
|
result_exp, lambda: [ArrayType(snowpark_typed_args[0].typ)]
|
|
1109
1109
|
)
|
|
1110
|
-
case "array_size"
|
|
1110
|
+
case "array_size":
|
|
1111
1111
|
array_type = snowpark_typed_args[0].typ
|
|
1112
1112
|
if not isinstance(array_type, ArrayType):
|
|
1113
1113
|
raise AnalysisException(
|
|
@@ -1116,6 +1116,16 @@ def map_unresolved_function(
|
|
|
1116
1116
|
result_exp = TypedColumn(
|
|
1117
1117
|
snowpark_fn.array_size(*snowpark_args), lambda: [LongType()]
|
|
1118
1118
|
)
|
|
1119
|
+
case "cardinality":
|
|
1120
|
+
arg_type = snowpark_typed_args[0].typ
|
|
1121
|
+
if isinstance(arg_type, (ArrayType, MapType)):
|
|
1122
|
+
result_exp = TypedColumn(
|
|
1123
|
+
snowpark_fn.size(*snowpark_args), lambda: [LongType()]
|
|
1124
|
+
)
|
|
1125
|
+
else:
|
|
1126
|
+
raise AnalysisException(
|
|
1127
|
+
f"Expected argument '{snowpark_arg_names[0]}' to have an ArrayType or MapType, but got {arg_type.simpleString()}."
|
|
1128
|
+
)
|
|
1119
1129
|
case "array_sort":
|
|
1120
1130
|
result_exp = TypedColumn(
|
|
1121
1131
|
snowpark_fn.array_sort(*snowpark_args),
|
|
@@ -1295,10 +1305,35 @@ def map_unresolved_function(
|
|
|
1295
1305
|
)
|
|
1296
1306
|
result_exp = TypedColumn(result_exp, lambda: [LongType()])
|
|
1297
1307
|
case "bit_get" | "getbit":
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
bit_get_function(*snowpark_args), lambda: [LongType()]
|
|
1308
|
+
snowflake_compat = get_boolean_session_config_param(
|
|
1309
|
+
"enable_snowflake_extension_behavior"
|
|
1301
1310
|
)
|
|
1311
|
+
col, pos = snowpark_args
|
|
1312
|
+
if snowflake_compat:
|
|
1313
|
+
bit_get_function = snowpark_fn.function("GETBIT")(col, pos)
|
|
1314
|
+
else:
|
|
1315
|
+
raise_error = _raise_error_helper(LongType())
|
|
1316
|
+
bit_get_function = snowpark_fn.when(
|
|
1317
|
+
(snowpark_fn.lit(0) <= pos) & (pos <= snowpark_fn.lit(63))
|
|
1318
|
+
| snowpark_fn.is_null(pos),
|
|
1319
|
+
snowpark_fn.function("GETBIT")(col, pos),
|
|
1320
|
+
).otherwise(
|
|
1321
|
+
raise_error(
|
|
1322
|
+
snowpark_fn.concat(
|
|
1323
|
+
snowpark_fn.lit(
|
|
1324
|
+
"Invalid bit position: ",
|
|
1325
|
+
),
|
|
1326
|
+
snowpark_fn.cast(
|
|
1327
|
+
pos,
|
|
1328
|
+
StringType(),
|
|
1329
|
+
),
|
|
1330
|
+
snowpark_fn.lit(
|
|
1331
|
+
" exceeds the bit upper limit",
|
|
1332
|
+
),
|
|
1333
|
+
)
|
|
1334
|
+
)
|
|
1335
|
+
)
|
|
1336
|
+
result_exp = TypedColumn(bit_get_function, lambda: [LongType()])
|
|
1302
1337
|
case "bit_length":
|
|
1303
1338
|
bit_length_function = snowpark_fn.function("bit_length")
|
|
1304
1339
|
result_exp = TypedColumn(
|
|
@@ -3787,7 +3822,13 @@ def map_unresolved_function(
|
|
|
3787
3822
|
case "locate":
|
|
3788
3823
|
substr = unwrap_literal(exp.unresolved_function.arguments[0])
|
|
3789
3824
|
value = snowpark_args[1]
|
|
3790
|
-
|
|
3825
|
+
if len(exp.unresolved_function.arguments) == 3:
|
|
3826
|
+
start_pos = unwrap_literal(exp.unresolved_function.arguments[2])
|
|
3827
|
+
else:
|
|
3828
|
+
# start_pos is an optional argument and if not provided we should default to 1.
|
|
3829
|
+
# This path will only be reached by spark connect scala clients.
|
|
3830
|
+
start_pos = 1
|
|
3831
|
+
spark_function_name = f"locate({', '.join(snowpark_arg_names)}, 1)"
|
|
3791
3832
|
|
|
3792
3833
|
if start_pos > 0:
|
|
3793
3834
|
result_exp = snowpark_fn.locate(substr, value, start_pos)
|
|
@@ -5496,9 +5537,27 @@ def map_unresolved_function(
|
|
|
5496
5537
|
):
|
|
5497
5538
|
result_exp = snowpark_fn.lit(None)
|
|
5498
5539
|
else:
|
|
5540
|
+
right_expr = snowpark_fn.right(*snowpark_args)
|
|
5541
|
+
if isinstance(snowpark_typed_args[0].typ, TimestampType):
|
|
5542
|
+
# Spark format is always displayed as YYY-MM-DD HH:mm:ss.FF6
|
|
5543
|
+
# When microseconds are equal to 0 .FF6 part is removed
|
|
5544
|
+
# When microseconds are equal to 0 at the end, they are removed i.e. .123000 -> .123 when displayed
|
|
5545
|
+
|
|
5546
|
+
formated_timestamp = snowpark_fn.to_varchar(
|
|
5547
|
+
snowpark_args[0], "YYYY-MM-DD HH:MI:SS.FF6"
|
|
5548
|
+
)
|
|
5549
|
+
right_expr = snowpark_fn.right(
|
|
5550
|
+
snowpark_fn.regexp_replace(
|
|
5551
|
+
snowpark_fn.regexp_replace(formated_timestamp, "0+$", ""),
|
|
5552
|
+
"\\.$",
|
|
5553
|
+
"",
|
|
5554
|
+
),
|
|
5555
|
+
snowpark_args[1],
|
|
5556
|
+
)
|
|
5557
|
+
|
|
5499
5558
|
result_exp = snowpark_fn.when(
|
|
5500
5559
|
snowpark_args[1] <= 0, snowpark_fn.lit("")
|
|
5501
|
-
).otherwise(
|
|
5560
|
+
).otherwise(right_expr)
|
|
5502
5561
|
result_type = StringType()
|
|
5503
5562
|
case "rint":
|
|
5504
5563
|
result_exp = snowpark_fn.cast(
|
|
@@ -6729,6 +6788,7 @@ def map_unresolved_function(
|
|
|
6729
6788
|
if value == "" or any(
|
|
6730
6789
|
c in value for c in [",", "\n", "\r", '"', "'"]
|
|
6731
6790
|
):
|
|
6791
|
+
value = value.replace("\\", "\\\\").replace('"', '\\"')
|
|
6732
6792
|
result.append(f'"{value}"')
|
|
6733
6793
|
else:
|
|
6734
6794
|
result.append(value)
|
|
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
|
|
|
16
16
|
from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\
|
|
19
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1csnowflake_relation_ext.proto\x12\rsnowflake.ext\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\"\xe3\x02\n\tExtension\x12(\n\x07rdd_map\x18\x01 \x01(\x0b\x32\x15.snowflake.ext.RddMapH\x00\x12.\n\nrdd_reduce\x18\x02 \x01(\x0b\x32\x18.snowflake.ext.RddReduceH\x00\x12G\n\x17subquery_column_aliases\x18\x03 \x01(\x0b\x32$.snowflake.ext.SubqueryColumnAliasesH\x00\x12\x32\n\x0clateral_join\x18\x04 \x01(\x0b\x32\x1a.snowflake.ext.LateralJoinH\x00\x12J\n\x19udtf_with_table_arguments\x18\x05 \x01(\x0b\x32%.snowflake.ext.UDTFWithTableArgumentsH\x00\x12-\n\taggregate\x18\x06 \x01(\x0b\x32\x18.snowflake.ext.AggregateH\x00\x42\x04\n\x02op\">\n\x06RddMap\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"A\n\tRddReduce\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"P\n\x15SubqueryColumnAliases\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x0f\n\x07\x61liases\x18\x02 \x03(\t\"\\\n\x0bLateralJoin\x12%\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12&\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.Relation\"\x98\x01\n\x16UDTFWithTableArguments\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12,\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x39\n\x0ftable_arguments\x18\x03 \x03(\x0b\x32 .snowflake.ext.TableArgumentInfo\"`\n\x11TableArgumentInfo\x12/\n\x0etable_argument\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x1a\n\x12table_argument_idx\x18\x02 \x01(\x05\"\xc7\x05\n\tAggregate\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x36\n\ngroup_type\x18\x02 \x01(\x0e\x32\".snowflake.ext.Aggregate.GroupType\x12\x37\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x38\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x12-\n\x05pivot\x18\x05 
\x01(\x0b\x32\x1e.snowflake.ext.Aggregate.Pivot\x12<\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.snowflake.ext.Aggregate.GroupingSets\x12\x33\n\x10having_condition\x18\x07 \x01(\x0b\x32\x19.spark.connect.Expression\x1a\x62\n\x05Pivot\x12&\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.Expression\x12\x31\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.Literal\x1a?\n\x0cGroupingSets\x12/\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05\x62\x06proto3')
|
|
20
20
|
|
|
21
21
|
_globals = globals()
|
|
22
22
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -38,11 +38,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
38
38
|
_globals['_TABLEARGUMENTINFO']._serialized_start=931
|
|
39
39
|
_globals['_TABLEARGUMENTINFO']._serialized_end=1027
|
|
40
40
|
_globals['_AGGREGATE']._serialized_start=1030
|
|
41
|
-
_globals['_AGGREGATE']._serialized_end=
|
|
42
|
-
_globals['_AGGREGATE_PIVOT']._serialized_start=
|
|
43
|
-
_globals['_AGGREGATE_PIVOT']._serialized_end=
|
|
44
|
-
_globals['_AGGREGATE_GROUPINGSETS']._serialized_start=
|
|
45
|
-
_globals['_AGGREGATE_GROUPINGSETS']._serialized_end=
|
|
46
|
-
_globals['_AGGREGATE_GROUPTYPE']._serialized_start=
|
|
47
|
-
_globals['_AGGREGATE_GROUPTYPE']._serialized_end=
|
|
41
|
+
_globals['_AGGREGATE']._serialized_end=1741
|
|
42
|
+
_globals['_AGGREGATE_PIVOT']._serialized_start=1416
|
|
43
|
+
_globals['_AGGREGATE_PIVOT']._serialized_end=1514
|
|
44
|
+
_globals['_AGGREGATE_GROUPINGSETS']._serialized_start=1516
|
|
45
|
+
_globals['_AGGREGATE_GROUPINGSETS']._serialized_end=1579
|
|
46
|
+
_globals['_AGGREGATE_GROUPTYPE']._serialized_start=1582
|
|
47
|
+
_globals['_AGGREGATE_GROUPTYPE']._serialized_end=1741
|
|
48
48
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -75,7 +75,7 @@ class TableArgumentInfo(_message.Message):
|
|
|
75
75
|
def __init__(self, table_argument: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., table_argument_idx: _Optional[int] = ...) -> None: ...
|
|
76
76
|
|
|
77
77
|
class Aggregate(_message.Message):
|
|
78
|
-
__slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets")
|
|
78
|
+
__slots__ = ("input", "group_type", "grouping_expressions", "aggregate_expressions", "pivot", "grouping_sets", "having_condition")
|
|
79
79
|
class GroupType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
|
80
80
|
__slots__ = ()
|
|
81
81
|
GROUP_TYPE_UNSPECIFIED: _ClassVar[Aggregate.GroupType]
|
|
@@ -108,10 +108,12 @@ class Aggregate(_message.Message):
|
|
|
108
108
|
AGGREGATE_EXPRESSIONS_FIELD_NUMBER: _ClassVar[int]
|
|
109
109
|
PIVOT_FIELD_NUMBER: _ClassVar[int]
|
|
110
110
|
GROUPING_SETS_FIELD_NUMBER: _ClassVar[int]
|
|
111
|
+
HAVING_CONDITION_FIELD_NUMBER: _ClassVar[int]
|
|
111
112
|
input: _relations_pb2.Relation
|
|
112
113
|
group_type: Aggregate.GroupType
|
|
113
114
|
grouping_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
|
|
114
115
|
aggregate_expressions: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
|
|
115
116
|
pivot: Aggregate.Pivot
|
|
116
117
|
grouping_sets: _containers.RepeatedCompositeFieldContainer[Aggregate.GroupingSets]
|
|
117
|
-
|
|
118
|
+
having_condition: _expressions_pb2.Expression
|
|
119
|
+
def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., group_type: _Optional[_Union[Aggregate.GroupType, str]] = ..., grouping_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., aggregate_expressions: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ..., pivot: _Optional[_Union[Aggregate.Pivot, _Mapping]] = ..., grouping_sets: _Optional[_Iterable[_Union[Aggregate.GroupingSets, _Mapping]]] = ..., having_condition: _Optional[_Union[_expressions_pb2.Expression, _Mapping]] = ...) -> None: ...
|