snowpark-connect 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -14
- snowflake/snowpark_connect/error/error_utils.py +32 -0
- snowflake/snowpark_connect/error/exceptions.py +4 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +192 -0
- snowflake/snowpark_connect/expression/literal.py +9 -12
- snowflake/snowpark_connect/expression/map_cast.py +20 -4
- snowflake/snowpark_connect/expression/map_expression.py +8 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +32 -5
- snowflake/snowpark_connect/expression/map_unresolved_function.py +269 -134
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +8 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +4 -2
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +127 -21
- snowflake/snowpark_connect/relation/map_aggregate.py +154 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +59 -8
- snowflake/snowpark_connect/relation/map_extension.py +58 -24
- snowflake/snowpark_connect/relation/map_local_relation.py +8 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +30 -1
- snowflake/snowpark_connect/relation/map_sql.py +40 -196
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read.py +2 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +8 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +7 -6
- snowflake/snowpark_connect/relation/utils.py +170 -1
- snowflake/snowpark_connect/relation/write/map_write.py +306 -87
- snowflake/snowpark_connect/server.py +34 -5
- snowflake/snowpark_connect/type_mapping.py +6 -2
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/env_utils.py +55 -0
- snowflake/snowpark_connect/utils/session.py +21 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/RECORD +55 -44
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.23.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.23.0.dist-info}/licenses/NOTICE-binary +0 -0
snowflake/snowpark_connect/config.py

@@ -4,10 +4,8 @@
 
 # Proto source for reference:
 # https://github.com/apache/spark/blob/branch-3.5/connector/connect/common/src/main/protobuf/spark/connect/base.proto#L420
-import os
 import re
 import sys
-import time
 from collections import defaultdict
 from copy import copy, deepcopy
 from typing import Any
@@ -168,6 +166,9 @@ class GlobalConfig:
         "snowpark.connect.udf.packages": lambda session, packages: session.add_packages(
             *packages.strip("[] ").split(",")
         ),
+        "snowpark.connect.udf.imports": lambda session, imports: parse_imports(
+            session, imports
+        ),
     }
 
     float_config_list = []
@@ -332,7 +333,7 @@ def route_config_proto(
     match op_type:
         case "set":
             logger.info("SET")
-
+            telemetry.report_config_set(config.operation.set.pairs)
             for pair in config.operation.set.pairs:
                 # Check if the value field is present, not present when invalid fields are set in conf.
                 if not pair.HasField("value"):
@@ -342,7 +343,6 @@ def route_config_proto(
                         f"Cannot set config '{pair.key}' to None"
                     )
 
-                telemetry.report_config_set(pair.key, pair.value)
                 set_config_param(
                     config.session_id, pair.key, pair.value, snowpark_session
                 )
@@ -350,14 +350,15 @@ def route_config_proto(
             return proto_base.ConfigResponse(session_id=config.session_id)
         case "unset":
             logger.info("UNSET")
+            telemetry.report_config_unset(config.operation.unset.keys)
             for key in config.operation.unset.keys:
-                telemetry.report_config_unset(key)
                 unset_config_param(config.session_id, key, snowpark_session)
 
             return proto_base.ConfigResponse(session_id=config.session_id)
         case "get":
             logger.info("GET")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.get.keys)
             for key in config.operation.get.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -367,6 +368,9 @@ def route_config_proto(
             return res
         case "get_with_default":
             logger.info("GET_WITH_DEFAULT")
+            telemetry.report_config_get(
+                [pair.key for pair in config.operation.get_with_default.pairs]
+            )
             result_pairs = [
                 proto_base.KeyValue(
                     key=pair.key,
@@ -383,6 +387,7 @@ def route_config_proto(
         case "get_option":
             logger.info("GET_OPTION")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.get_option.keys)
             for key in config.operation.get_option.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -411,6 +416,7 @@ def route_config_proto(
         case "is_modifiable":
             logger.info("IS_MODIFIABLE")
             res = proto_base.ConfigResponse(session_id=config.session_id)
+            telemetry.report_config_get(config.operation.is_modifiable.keys)
             for key in config.operation.is_modifiable.keys:
                 pair = res.pairs.add()
                 pair.key = key
@@ -525,17 +531,8 @@ def set_snowflake_parameters(
                 snowpark_session.sql(
                     f"ALTER SESSION SET TIMEZONE = '{value}'"
                 ).collect()
-                set_jvm_timezone(value)
-                if hasattr(time, "tzset"):
-                    os.environ["TZ"] = value
-                    time.tzset()
             else:
                 snowpark_session.sql("ALTER SESSION UNSET TIMEZONE").collect()
-                reset_jvm_timezone_to_system_default()
-                if hasattr(time, "tzset") and False:
-                    if "TZ" in os.environ:
-                        del os.environ["TZ"]
-                        time.tzset()
         case "spark.sql.globalTempDatabase":
             if not value:
                 value = global_config.default_static_global_config.get(key)
@@ -588,3 +585,11 @@ def auto_uppercase_non_column_identifiers() -> bool:
     return session_config[
         "snowpark.connect.sql.identifiers.auto-uppercase"
     ].lower() in ("all", "all_except_columns")
+
+
+def parse_imports(session: snowpark.Session, imports: str | None) -> None:
+    if not imports:
+        return
+
+    for udf_import in imports.strip("[] ").split(","):
+        session.add_import(udf_import)
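The new `snowpark.connect.udf.imports` option is handled by `parse_imports`, which strips the surrounding brackets, splits on commas, and hands each entry to `Session.add_import`. A minimal sketch of how a Spark Connect client might set it, assuming the bracketed list format implied by the parser; the stage paths and endpoint below are illustrative, not taken from the release:

```python
# Sketch only: setting the new UDF import option from a Spark Connect client.
# The "[...]" format mirrors parse_imports' strip("[] ").split(","); the stage
# paths and the connect endpoint are assumptions.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

spark.conf.set("snowpark.connect.udf.packages", "[numpy,pandas]")
spark.conf.set(
    "snowpark.connect.udf.imports",
    "[@my_stage/helpers.py,@my_stage/lookup_table.csv]",
)
```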
snowflake/snowpark_connect/error/error_utils.py

@@ -28,7 +28,9 @@ from pyspark.errors.exceptions.base import (
     PySparkException,
     PythonException,
     SparkRuntimeException,
+    UnsupportedOperationException,
 )
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from snowflake.core.exceptions import NotFoundError
 
 from snowflake.connector.errors import ProgrammingError
@@ -49,7 +51,9 @@ SPARK_PYTHON_TO_JAVA_EXCEPTION = {
     ArrayIndexOutOfBoundsException: "java.lang.ArrayIndexOutOfBoundsException",
     NumberFormatException: "java.lang.NumberFormatException",
     SparkRuntimeException: "org.apache.spark.SparkRuntimeException",
+    SparkConnectGrpcException: "pyspark.errors.exceptions.connect.SparkConnectGrpcException",
     PythonException: "org.apache.spark.api.python.PythonException",
+    UnsupportedOperationException: "java.lang.UnsupportedOperationException",
 }
 
 WINDOW_FUNCTION_ANALYSIS_EXCEPTION_SQL_ERROR_CODE = {1005, 2303}
@@ -68,6 +72,13 @@ init_multi_args_exception_pattern = (
 terminate_multi_args_exception_pattern = (
     r"terminate\(\) missing \d+ required positional argument"
 )
+snowpark_connect_exception_pattern = re.compile(
+    r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
+)
+invalid_bit_pattern = re.compile(
+    r"Invalid bit position: \d+ exceeds the bit (?:upper|lower) limit",
+    re.IGNORECASE,
+)
 
 
 def contains_udtf_select(sql_string):
@@ -100,6 +111,22 @@ def _get_converted_known_sql_or_custom_exception(
         return SparkRuntimeException(
             message="Unexpected value for start in function slice: SQL array indices start at 1."
         )
+    invalid_bit = invalid_bit_pattern.search(msg)
+    if invalid_bit:
+        return IllegalArgumentException(message=invalid_bit.group(0))
+    match = snowpark_connect_exception_pattern.search(
+        ex.message if hasattr(ex, "message") else str(ex)
+    )
+    if match:
+        class_name = match.group(1)
+        message = match.group(2)
+        exception_class = (
+            globals().get(class_name, SparkConnectGrpcException)
+            if class_name
+            else SparkConnectGrpcException
+        )
+        return exception_class(message=message)
+
     if "select with no columns" in msg and contains_udtf_select(query):
         # We try our best to detect if the SQL string contains a UDTF call and the output schema is empty.
         return PythonException(message=f"[UDTF_RETURN_SCHEMA_MISMATCH] {ex.message}")
@@ -131,6 +158,11 @@ def _get_converted_known_sql_or_custom_exception(
             message=f"[UDTF_EXEC_ERROR] User defined table function encountered an error in the terminate method: {ex.message}"
         )
 
+    if "failed to split string, provided pattern:" in msg:
+        return IllegalArgumentException(
+            message=f"Failed to split string using provided pattern. {ex.message}"
+        )
+
     if "100357" in msg and "wrong tuple size for returned value" in msg:
         return PythonException(
             message=f"[UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not match the specified schema. {ex.message}"
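The new `snowpark_connect_exception_pattern` carries an optional exception class name plus a message inside server error text, and the converter resolves the class name through `globals()`, falling back to `SparkConnectGrpcException` when no class is named. A small, self-contained illustration of how the regex is expected to match; the sample error string is invented:

```python
# Illustration of the new error-message pattern; the sample string is made up.
import re

snowpark_connect_exception_pattern = re.compile(
    r"\[snowpark-connect-exception(?::(\w+))?\]\s*(.+?)'\s*is not recognized"
)

sample = "[snowpark-connect-exception:UnsupportedOperationException] foo()' is not recognized"
m = snowpark_connect_exception_pattern.search(sample)
assert m is not None
print(m.group(1))  # "UnsupportedOperationException" -> looked up in globals()
print(m.group(2))  # "foo()" -> becomes the message of the raised exception
```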
snowflake/snowpark_connect/expression/hybrid_column_map.py (new file)

@@ -0,0 +1,192 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+Hybrid column mapping for HAVING clause resolution.
+
+This module provides a special column mapping that can resolve expressions
+in the context of both the input DataFrame (for base columns) and the
+aggregated DataFrame (for aggregate expressions and aliases).
+"""
+
+from typing import Dict, List
+
+import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
+
+from snowflake import snowpark
+from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
+from snowflake.snowpark_connect.expression.typer import ExpressionTyper
+from snowflake.snowpark_connect.typed_column import TypedColumn
+
+
+class HybridColumnMap:
+    """
+    A column mapping that can resolve expressions in both input and aggregated contexts.
+
+    This is specifically designed for HAVING clause resolution where expressions may reference:
+    1. Base columns from the input DataFrame (to build new aggregates)
+    2. Existing aggregate expressions and their aliases
+    3. Grouping columns
+    """
+
+    def __init__(
+        self,
+        input_column_map: ColumnNameMap,
+        input_typer: ExpressionTyper,
+        aggregated_column_map: ColumnNameMap,
+        aggregated_typer: ExpressionTyper,
+        aggregate_expressions: List[expressions_proto.Expression],
+        grouping_expressions: List[expressions_proto.Expression],
+        aggregate_aliases: Dict[str, expressions_proto.Expression],
+    ) -> None:
+        self.input_column_map = input_column_map
+        self.input_typer = input_typer
+        self.aggregated_column_map = aggregated_column_map
+        self.aggregated_typer = aggregated_typer
+        self.aggregate_expressions = aggregate_expressions
+        self.grouping_expressions = grouping_expressions
+        self.aggregate_aliases = aggregate_aliases
+
+    def is_aggregate_function(self, exp: expressions_proto.Expression) -> bool:
+        """Check if an expression is an aggregate function."""
+        if exp.WhichOneof("expr_type") == "unresolved_function":
+            func_name = exp.unresolved_function.function_name.lower()
+            # Common aggregate functions - expand this list as needed
+            aggregate_functions = {
+                "avg",
+                "average",
+                "sum",
+                "count",
+                "min",
+                "max",
+                "stddev",
+                "stddev_pop",
+                "stddev_samp",
+                "variance",
+                "var_pop",
+                "var_samp",
+                "collect_list",
+                "collect_set",
+                "first",
+                "last",
+                "any_value",
+                "bool_and",
+                "bool_or",
+                "corr",
+                "covar_pop",
+                "covar_samp",
+                "kurtosis",
+                "skewness",
+                "percentile_cont",
+                "percentile_disc",
+                "approx_count_distinct",
+            }
+            return func_name in aggregate_functions
+        return False
+
+    def is_grouping_column(self, column_name: str) -> bool:
+        """Check if a column name refers to a grouping column."""
+        for group_exp in self.grouping_expressions:
+            if (
+                group_exp.WhichOneof("expr_type") == "unresolved_attribute"
+                and group_exp.unresolved_attribute.unparsed_identifier == column_name
+            ):
+                return True
+        return False
+
+    def resolve_expression(
+        self, exp: expressions_proto.Expression
+    ) -> tuple[list[str], TypedColumn]:
+        """
+        Resolve an expression in the hybrid context.
+
+        Strategy:
+        1. If it's an aggregate function -> create new aggregate using input context
+        2. If it's an alias to existing aggregate -> use aggregated context
+        3. If it's a grouping column -> try aggregated context first, fall back to input context
+           (handles exclude_grouping_columns=True case)
+        4. Otherwise -> try input context first, then aggregated context
+        """
+        from snowflake.snowpark_connect.expression.map_expression import map_expression
+
+        expr_type = exp.WhichOneof("expr_type")
+
+        # Handle aggregate functions - need to evaluate against input DataFrame
+        if self.is_aggregate_function(exp):
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+        # Handle column references
+        if expr_type == "unresolved_attribute":
+            column_name = exp.unresolved_attribute.unparsed_identifier
+
+            # Check if it's an alias to an existing aggregate expression
+            if column_name in self.aggregate_aliases:
+                # Use the aggregated context to get the alias
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+            # Check if it's a grouping column
+            if self.is_grouping_column(column_name):
+                # Try aggregated context first (for cases where grouping columns are included)
+                try:
+                    return map_expression(
+                        exp, self.aggregated_column_map, self.aggregated_typer
+                    )
+                except Exception:
+                    # Fall back to input context if grouping columns were excluded
+                    # This handles the exclude_grouping_columns=True case
+                    return map_expression(exp, self.input_column_map, self.input_typer)
+
+            # Try input context first (for base columns used in new aggregates)
+            try:
+                return map_expression(exp, self.input_column_map, self.input_typer)
+            except Exception:
+                # Fall back to aggregated context
+                return map_expression(
+                    exp, self.aggregated_column_map, self.aggregated_typer
+                )
+
+        # For other expression types, try aggregated context first (likely references to computed values)
+        try:
+            return map_expression(
+                exp, self.aggregated_column_map, self.aggregated_typer
+            )
+        except Exception:
+            # Fall back to input context
+            return map_expression(exp, self.input_column_map, self.input_typer)
+
+
+def create_hybrid_column_map_for_having(
+    input_df: snowpark.DataFrame,
+    input_column_map: ColumnNameMap,
+    aggregated_df: snowpark.DataFrame,
+    aggregated_column_map: ColumnNameMap,
+    aggregate_expressions: List[expressions_proto.Expression],
+    grouping_expressions: List[expressions_proto.Expression],
+    spark_columns: List[str],
+    raw_aggregations: List[tuple[str, TypedColumn]],
+) -> HybridColumnMap:
+    """
+    Create a HybridColumnMap instance for HAVING clause resolution.
+    """
+    # Create typers for both contexts
+    input_typer = ExpressionTyper(input_df)
+    aggregated_typer = ExpressionTyper(aggregated_df)
+
+    # Build alias mapping from spark column names to aggregate expressions
+    aggregate_aliases = {}
+    for i, (spark_name, _) in enumerate(raw_aggregations):
+        if i < len(aggregate_expressions):
+            aggregate_aliases[spark_name] = aggregate_expressions[i]
+
+    return HybridColumnMap(
+        input_column_map=input_column_map,
+        input_typer=input_typer,
+        aggregated_column_map=aggregated_column_map,
+        aggregated_typer=aggregated_typer,
+        aggregate_expressions=aggregate_expressions,
+        grouping_expressions=grouping_expressions,
+        aggregate_aliases=aggregate_aliases,
+    )
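HybridColumnMap lets a HAVING predicate mix an aliased aggregate, a grouping column, and a brand-new aggregate over a base column that no longer exists after aggregation. The kind of query this resolution targets, written against an assumed `employees` table (table and column names are illustrative only):

```python
# Illustrative HAVING query of the kind the hybrid mapping resolves; the table
# and column names are assumptions, not taken from the package.
spark.sql("""
    SELECT dept, SUM(salary) AS total_salary
    FROM employees
    GROUP BY dept
    HAVING total_salary > 100000 AND AVG(salary) > 5000
""").show()
```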
snowflake/snowpark_connect/expression/literal.py

@@ -10,6 +10,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 from tzlocal import get_localzone
 
 from snowflake.snowpark_connect.config import global_config
+from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -47,17 +48,7 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
             ).date()
             return date, f"DATE '{date}'"
         case "timestamp" | "timestamp_ntz" as t:
-
-            # No need to apply timezone for lit datetime, because we set the TIMEZONE parameter in snowpark session,
-            # the snowflake backend would convert the lit datetime correctly. However, for returned column name, the
-            # timezone needs to be added. Pyspark has a weird behavior that datetime.datetime always gets converted
-            # to local timezone before printing according to spark_sql_session_timeZone setting. Haven't found
-            # official doc about it, but this behavior is based on my testings.
-            tz = (
-                ZoneInfo(global_config.spark_sql_session_timeZone)
-                if hasattr(global_config, "spark_sql_session_timeZone")
-                else get_localzone()
-            )
+            local_tz = get_localzone()
             if t == "timestamp":
                 microseconds = literal.timestamp
             else:
@@ -66,11 +57,17 @@ def get_literal_field_and_name(literal: expressions_proto.Expression.Literal):
                 microseconds // 1_000_000
             ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
             tz_dt = datetime.datetime.fromtimestamp(
-                microseconds // 1_000_000, tz=tz
+                microseconds // 1_000_000, tz=local_tz
             ) + datetime.timedelta(microseconds=microseconds % 1_000_000)
             if t == "timestamp_ntz":
                 lit_dt = lit_dt.astimezone(datetime.timezone.utc)
                 tz_dt = tz_dt.astimezone(datetime.timezone.utc)
+            elif not get_is_evaluating_sql():
+                config_tz = global_config.spark_sql_session_timeZone
+                config_tz = ZoneInfo(config_tz) if config_tz else local_tz
+                tz_dt = tz_dt.astimezone(config_tz)
+                lit_dt = lit_dt.astimezone(local_tz)
+
             return lit_dt, f"{t.upper()} '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'"
         case "day_time_interval":
             # TODO(SNOW-1920942): Snowflake SQL is missing an "interval" type.
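Outside SQL evaluation, the displayed name of a timestamp literal now follows `spark.sql.session.timeZone` while the literal value itself stays in the local zone. A standalone sketch of that conversion using the same building blocks; the epoch value and the configured zone are made-up inputs, and the real code handles more cases (NTZ, SQL evaluation):

```python
# Standalone sketch of the timestamp-literal rendering; inputs are invented.
import datetime
from zoneinfo import ZoneInfo
from tzlocal import get_localzone

microseconds = 1_700_000_000_000_000        # literal payload: microseconds since epoch
local_tz = get_localzone()
config_tz = ZoneInfo("America/Los_Angeles") # stand-in for spark.sql.session.timeZone

lit_dt = datetime.datetime.fromtimestamp(
    microseconds // 1_000_000, tz=local_tz
) + datetime.timedelta(microseconds=microseconds % 1_000_000)
tz_dt = lit_dt.astimezone(config_tz)        # used only for the displayed name

print(lit_dt, f"TIMESTAMP '{tz_dt.strftime('%Y-%m-%d %H:%M:%S')}'")
```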
snowflake/snowpark_connect/expression/map_cast.py

@@ -127,10 +127,11 @@ def map_cast(
         from_type = StringType()
     if isinstance(to_type, StringType):
         to_type = StringType()
+
+    # todo - verify if that's correct SNOW-2248680
     if isinstance(from_type, TimestampType):
         from_type = TimestampType()
-
-        to_type = TimestampType()
+
     match (from_type, to_type):
         case (_, _) if (from_type == to_type):
             result_exp = col
@@ -185,6 +186,17 @@ def map_cast(
         case (DateType(), TimestampType()):
             result_exp = snowpark_fn.to_timestamp(col)
             result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
+        case (TimestampType() as f, TimestampType() as t) if f.tzinfo == t.tzinfo:
+            result_exp = col
+        case (
+            TimestampType(),
+            TimestampType() as t,
+        ) if t.tzinfo == TimestampTimeZone.NTZ:
+            zone = global_config.spark_sql_session_timeZone
+            result_exp = snowpark_fn.convert_timezone(snowpark_fn.lit(zone), col).cast(
+                TimestampType(TimestampTimeZone.NTZ)
+            )
+        # todo: verify if more support for LTZ and TZ is needed - SNOW-2248680
         case (TimestampType(), TimestampType()):
             result_exp = col
         case (_, TimestampType()) if isinstance(from_type, _NumericType):
@@ -259,8 +271,12 @@ def map_cast(
         case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
             to_type, _IntegralType
         ):
-            result_exp =
-                snowpark_fn.
+            result_exp = (
+                snowpark_fn.when(
+                    col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                )
+                .when(col < 0, snowpark_fn.ceil(col))
+                .otherwise(snowpark_fn.floor(col))
             )
             result_exp = result_exp.cast(to_type)
         case (StringType(), _) if (isinstance(to_type, _IntegralType)):
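For fractional-to-integral casts, the new when/otherwise chain maps NaN to 0 and truncates toward zero (ceil for negatives, floor otherwise) before casting. A plain-Python restatement of that rule:

```python
# Plain-Python restatement of the fractional -> integral cast rule from the diff.
import math

def truncate_toward_zero(x: float) -> int:
    if math.isnan(x):
        return 0                               # NaN becomes 0
    return math.ceil(x) if x < 0 else math.floor(x)

assert truncate_toward_zero(float("nan")) == 0
assert truncate_toward_zero(-3.7) == -3        # ceil for negatives
assert truncate_toward_zero(3.7) == 3          # floor for non-negatives
```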
snowflake/snowpark_connect/expression/map_expression.py

@@ -11,6 +11,7 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute
+from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.expression import (
     map_extension,
@@ -190,9 +191,15 @@ def map_expression(
                 return [lit_name], TypedColumn(
                     snowpark_fn.lit(lit_value, return_type), lambda: [return_type]
                 )
+            result_exp = snowpark_fn.lit(lit_value)
+
+            if lit_type_str == "timestamp_ntz" and isinstance(
+                lit_value, datetime.datetime
+            ):
+                result_exp = result_exp.cast(TimestampType(TimestampTimeZone.NTZ))
 
             return [lit_name], TypedColumn(
-
+                result_exp, lambda: [map_simple_types(lit_type_str)]
             )
         case "sort_order":
             child_name, child_column = map_single_column_expression(
snowflake/snowpark_connect/expression/map_udf.py

@@ -13,10 +13,7 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.typed_column import TypedColumn
-from snowflake.snowpark_connect.utils.session import (
-    get_or_create_snowpark_session,
-    get_python_udxf_import_files,
-)
+from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
     gen_input_types,
@@ -28,6 +25,9 @@ from snowflake.snowpark_connect.utils.udf_helper import (
 from snowflake.snowpark_connect.utils.udf_utils import (
     ProcessCommonInlineUserDefinedFunction,
 )
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)
 
 
 def process_udf_return_type(
snowflake/snowpark_connect/expression/map_unresolved_extract_value.py

@@ -5,6 +5,7 @@
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 
 import snowflake.snowpark.functions as snowpark_fn
+from snowflake.snowpark._internal.analyzer.expression import Literal
 from snowflake.snowpark.types import ArrayType, MapType, StructType, _IntegralType
 from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
@@ -57,7 +58,8 @@ def map_unresolved_extract_value(
         extract_fn = snowpark_fn.get_ignore_case
     # Set index to a dummy value before we use it later in the ansi mode check.
     index = snowpark_fn.lit(1)
-
+    is_array = _check_if_array_type(extract_typed_column, child_typed_column)
+    if is_array:
         # Set all non-valid array indices to NULL.
         # This is done because both conditions of a CASE WHEN statement are executed regardless of if the condition is true or not.
         # Getting a negative index in Snowflake throws an error; thus, we convert all non-valid array indices to NULL before getting the index.
@@ -74,12 +76,37 @@ def map_unresolved_extract_value(
 
     spark_sql_ansi_enabled = global_config.spark_sql_ansi_enabled
 
-    if spark_sql_ansi_enabled and
-        extract_typed_column, child_typed_column
-    ):
+    if spark_sql_ansi_enabled and is_array:
         result_exp = snowpark_fn.when(
             index.isNull(),
             child_typed_column.col.getItem("[snowpark_connect::INVALID_ARRAY_INDEX]"),
         ).otherwise(result_exp)
 
-
+    def _get_extracted_value_type():
+        if is_array:
+            return [child_typed_column.typ.element_type]
+        elif isinstance(child_typed_column.typ, MapType):
+            return [child_typed_column.typ.value_type]
+        elif (
+            isinstance(child_typed_column.typ, StructType)
+            and isinstance(extract_typed_column.col._expr1, Literal)
+            and isinstance(extract_typed_column.col._expr1.value, str)
+        ):
+            struct = dict(
+                {
+                    (
+                        f.name
+                        if global_config.spark_sql_caseSensitive
+                        else f.name.lower(),
+                        f.datatype,
+                    )
+                    for f in child_typed_column.typ.fields
+                }
+            )
+            key = extract_typed_column.col._expr1.value
+            key = key if global_config.spark_sql_caseSensitive else key.lower()
+
+            return [struct[key]] if key in struct else typer.type(result_exp)
+        return typer.type(result_exp)
+
+    return spark_function_name, TypedColumn(result_exp, _get_extracted_value_type)