snowpark-connect 0.20.2__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +3 -2
- snowflake/snowpark_connect/column_name_handler.py +6 -65
- snowflake/snowpark_connect/config.py +47 -17
- snowflake/snowpark_connect/dataframe_container.py +242 -0
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +13 -23
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +9 -5
- snowflake/snowpark_connect/expression/map_extension.py +2 -1
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +8 -7
- snowflake/snowpark_connect/expression/map_unresolved_function.py +481 -170
- snowflake/snowpark_connect/expression/map_unresolved_star.py +8 -8
- snowflake/snowpark_connect/expression/map_update_fields.py +1 -1
- snowflake/snowpark_connect/expression/typer.py +6 -6
- snowflake/snowpark_connect/proto/control_pb2.py +17 -16
- snowflake/snowpark_connect/proto/control_pb2.pyi +17 -17
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +12 -63
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +15 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +19 -14
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +27 -26
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +74 -68
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +5 -5
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +25 -17
- snowflake/snowpark_connect/relation/map_aggregate.py +170 -61
- snowflake/snowpark_connect/relation/map_catalog.py +2 -2
- snowflake/snowpark_connect/relation/map_column_ops.py +227 -145
- snowflake/snowpark_connect/relation/map_crosstab.py +25 -6
- snowflake/snowpark_connect/relation/map_extension.py +81 -56
- snowflake/snowpark_connect/relation/map_join.py +72 -63
- snowflake/snowpark_connect/relation/map_local_relation.py +35 -20
- snowflake/snowpark_connect/relation/map_map_partitions.py +24 -17
- snowflake/snowpark_connect/relation/map_relation.py +22 -16
- snowflake/snowpark_connect/relation/map_row_ops.py +232 -146
- snowflake/snowpark_connect/relation/map_sample_by.py +15 -8
- snowflake/snowpark_connect/relation/map_show_string.py +42 -5
- snowflake/snowpark_connect/relation/map_sql.py +141 -237
- snowflake/snowpark_connect/relation/map_stats.py +88 -39
- snowflake/snowpark_connect/relation/map_subquery_alias.py +13 -14
- snowflake/snowpark_connect/relation/map_udtf.py +10 -13
- snowflake/snowpark_connect/relation/read/map_read.py +8 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_json.py +19 -8
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -7
- snowflake/snowpark_connect/relation/read/map_read_socket.py +7 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +25 -16
- snowflake/snowpark_connect/relation/read/map_read_text.py +7 -7
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/utils.py +11 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +15 -12
- snowflake/snowpark_connect/relation/write/map_write.py +259 -56
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +3 -2
- snowflake/snowpark_connect/server.py +43 -4
- snowflake/snowpark_connect/type_mapping.py +6 -23
- snowflake/snowpark_connect/utils/cache.py +27 -22
- snowflake/snowpark_connect/utils/context.py +33 -17
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/{attribute_handling.py → identifiers.py} +47 -0
- snowflake/snowpark_connect/utils/session.py +41 -38
- snowflake/snowpark_connect/utils/telemetry.py +214 -63
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +6 -4
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +83 -69
- snowpark_connect-0.22.1.dist-info/licenses/LICENSE-binary +568 -0
- snowpark_connect-0.22.1.dist-info/licenses/NOTICE-binary +1533 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.20.2.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.20.2.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
snowflake/snowpark_connect/utils/cache.py:

@@ -8,53 +8,58 @@ from typing import Dict, Tuple
 
 import pandas
 
-from snowflake import
-from snowflake.snowpark_connect.column_name_handler import set_schema_getter
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 
-# global cache mapping (sessionID, planID) -> cached snowpark
-df_cache_map: Dict[Tuple[str, any],
+# global cache mapping (sessionID, planID) -> cached snowpark dataframe container.
+df_cache_map: Dict[Tuple[str, any], DataFrameContainer] = {}
 
 # reentrant lock for thread safety
 _cache_map_lock = threading.RLock()
 
 
-def df_cache_map_get(key: Tuple[str, any]) ->
+def df_cache_map_get(key: Tuple[str, any]) -> DataFrameContainer | None:
     with _cache_map_lock:
         return df_cache_map.get(key)
 
 
 def df_cache_map_put_if_absent(
     key: Tuple[str, any],
-    compute_fn: Callable[[],
+    compute_fn: Callable[[], DataFrameContainer | pandas.DataFrame],
     materialize: bool,
-) ->
+) -> DataFrameContainer | pandas.DataFrame:
     """
-    Put a DataFrame into the cache map if the key is absent. Optionally, as side effect, materialize
+    Put a DataFrame container into the cache map if the key is absent. Optionally, as side effect, materialize
     the DataFrame content in a temporary table.
 
     Args:
         key (Tuple[str, int]): The key to insert into the cache map (session_id, plan_id).
-        compute_fn (Callable[[], DataFrame]): A function to compute the DataFrame if the key is absent.
+        compute_fn (Callable[[], DataFrameContainer | pandas.DataFrame]): A function to compute the DataFrame container if the key is absent.
         materialize (bool): Whether to materialize the DataFrame.
 
     Returns:
-
+        DataFrameContainer | pandas.DataFrame: The cached or newly computed DataFrame container.
     """
 
-    def _object_to_cache(
+    def _object_to_cache(
+        container: DataFrameContainer,
+    ) -> DataFrameContainer:
+
         if materialize:
+            df = container.dataframe
             cached_result = df.cache_result()
-
-
-
-
-
-
+            return DataFrameContainer(
+                dataframe=cached_result,
+                column_map=container.column_map,
+                table_name=container.table_name,
+                alias=container.alias,
+                cached_schema_getter=lambda: df.schema,
+            )
+        return container
 
     with _cache_map_lock:
         if key not in df_cache_map:
-
+            result = compute_fn()
 
             # check cache again, since recursive call in compute_fn could've already cached the result.
             # we want return it, instead of saving it again. This is important if materialize = True
@@ -62,19 +67,19 @@ def df_cache_map_put_if_absent(
             if key in df_cache_map:
                 return df_cache_map[key]
 
-            # only cache
+            # only cache DataFrameContainer, but not pandas result.
             # Pandas result is only returned when df.show() is called, where we convert
             # a dataframe to a string representation.
             # We don't expect map_relation would return pandas df here because that would
             # be equivalent to calling df.show().cache(), which is not allowed.
-            if isinstance(
-                df_cache_map[key] = _object_to_cache(
+            if isinstance(result, DataFrameContainer):
+                df_cache_map[key] = _object_to_cache(result)
             else:
                 # This is not expected, but we will just log a warning
                 logger.warning(
                     "Unexpected pandas dataframe returned for caching. Ignoring the cache call."
                 )
-                return
+                return result
 
         return df_cache_map[key]
 
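The hunks above switch the per-plan cache from bare Snowpark DataFrames to DataFrameContainer objects, which also carry the column map, table name, and alias. A minimal sketch of how a caller might use the new API, assuming a container can be built from just a dataframe and its column map (the `cache_plan` helper and its arguments are hypothetical, not part of the package):

```python
# Hypothetical illustration of the container-based cache API shown in the diff.
from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
from snowflake.snowpark_connect.utils.cache import df_cache_map_put_if_absent


def cache_plan(session_id: str, plan_id: int, snowpark_df, column_map):
    def compute() -> DataFrameContainer:
        # Wrap the Snowpark DataFrame together with its Spark->Snowflake column mapping.
        # (Constructor arguments beyond these two are assumptions based on the diff.)
        return DataFrameContainer(dataframe=snowpark_df, column_map=column_map)

    # The first call for this (session_id, plan_id) computes and, with
    # materialize=True, snapshots the result via cache_result(); later calls
    # return the cached container.
    return df_cache_map_put_if_absent(
        key=(session_id, plan_id),
        compute_fn=compute,
        materialize=True,
    )
```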
snowflake/snowpark_connect/utils/context.py:

@@ -9,14 +9,14 @@ from typing import Mapping, Optional
 
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 
-from snowflake import
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.typed_column import TypedColumn
 
 # TODO: remove session id from context when we host SAS in Snowflake server
 
 _session_id = ContextVar[str]("_session_id")
-_plan_id_map = ContextVar[Mapping[int,
-_alias_map = ContextVar[Mapping[str,
+_plan_id_map = ContextVar[Mapping[int, DataFrameContainer]]("_plan_id_map")
+_alias_map = ContextVar[Mapping[str, DataFrameContainer | None]]("_alias_map")
 _spark_version = ContextVar[str]("_spark_version")
 _is_aggregate_function = ContextVar(
     "_is_aggregate_function", default=("default", False)
@@ -40,10 +40,10 @@ _sql_named_args = ContextVar[dict[str, expressions_proto.Expression]]("_sql_name
 _sql_pos_args = ContextVar[dict[int, expressions_proto.Expression]]("_sql_pos_args")
 
 # Used to store the df before the last projection operation
-_df_before_projection = ContextVar[
+_df_before_projection = ContextVar[DataFrameContainer | None](
     "_df_before_projection", default=None
 )
-_outer_dataframes = ContextVar[list[
+_outer_dataframes = ContextVar[list[DataFrameContainer]]("_parent_dataframes")
 
 _spark_client_type_regex = re.compile(r"spark/(?P<spark_version>\d+\.\d+\.\d+)")
 _current_operation = ContextVar[str]("_current_operation", default="default")
@@ -66,6 +66,12 @@ _lca_alias_map: ContextVar[dict[str, TypedColumn]] = ContextVar(
     default={},
 )
 
+# Context variable to track current grouping columns for grouping_id() function
+_current_grouping_columns: ContextVar[list[str]] = ContextVar(
+    "_current_grouping_columns",
+    default=[],
+)
+
 
 def clear_lca_alias_map() -> None:
     _lca_alias_map.set({})
@@ -87,6 +93,16 @@ def resolve_lca_alias(name: str) -> Optional[TypedColumn]:
     return _lca_alias_map.get().get(_normalize(name))
 
 
+def set_current_grouping_columns(columns: list[str]) -> None:
+    """Set the current grouping columns for grouping_id() function."""
+    _current_grouping_columns.set(columns)
+
+
+def get_current_grouping_columns() -> list[str]:
+    """Get the current grouping columns for grouping_id() function."""
+    return _current_grouping_columns.get()
+
+
 def set_session_id(value: str) -> None:
     """Set the session ID for the current context"""
     _session_id.set(value)
@@ -97,13 +113,13 @@ def get_session_id() -> str:
     return _session_id.get(None)
 
 
-def set_plan_id_map(plan_id: int,
+def set_plan_id_map(plan_id: int, container: DataFrameContainer) -> None:
     """Set the plan id map for the current context."""
-    _plan_id_map.get()[plan_id] =
+    _plan_id_map.get()[plan_id] = container
 
 
-def get_plan_id_map(plan_id: int) ->
-    """
+def get_plan_id_map(plan_id: int) -> DataFrameContainer | None:
+    """Get the plan id map for the current context."""
     return _plan_id_map.get().get(plan_id)
 
 
@@ -295,30 +311,30 @@ def get_sql_pos_arg(pos: int) -> expressions_proto.Expression:
     return _sql_pos_args.get()[pos]
 
 
-def set_df_before_projection(df:
+def set_df_before_projection(df: DataFrameContainer | None) -> None:
     """
-    Sets the current DataFrame in the context.
-    This is used to track the DataFrame in the current context.
+    Sets the current DataFrame container in the context.
+    This is used to track the DataFrame container in the current context.
     """
     _df_before_projection.set(df)
 
 
-def get_df_before_projection() ->
+def get_df_before_projection() -> DataFrameContainer | None:
     """
-    Returns the current DataFrame if set, otherwise None.
-    This is used to track the DataFrame in the current context.
+    Returns the current DataFrame container if set, otherwise None.
+    This is used to track the DataFrame container in the current context.
     """
     return _df_before_projection.get()
 
 
 @contextmanager
-def push_outer_dataframe(df:
+def push_outer_dataframe(df: DataFrameContainer):
     _outer_dataframes.get().append(df)
     yield
     _outer_dataframes.get().pop()
 
 
-def get_outer_dataframes() -> list[
+def get_outer_dataframes() -> list[DataFrameContainer]:
     return _outer_dataframes.get()
 
 
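Besides migrating the context variables to DataFrameContainer, these hunks add a context variable that records the grouping columns in scope so that a later grouping_id() expression can be resolved. A small sketch of that flow; the mapper function below is a stand-in for illustration, not the package's actual code:

```python
# Illustrative only: how the new grouping-column context variables could be used
# around an aggregate translation.
from snowflake.snowpark_connect.utils.context import (
    get_current_grouping_columns,
    set_current_grouping_columns,
)


def translate_aggregate_example(grouping_cols: list[str]) -> None:
    # The relation mapper records the GROUP BY columns for the current plan ...
    set_current_grouping_columns(grouping_cols)
    try:
        # ... so an expression mapper handling grouping_id() can read them back
        # without threading them through every call.
        print("grouping columns in scope:", get_current_grouping_columns())
    finally:
        # Reset so later plans in the same context start clean.
        set_current_grouping_columns([])
```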
snowflake/snowpark_connect/utils/describe_query_cache.py:

@@ -131,21 +131,14 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
             logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
             cache.clear()
 
-    def report_query(qid: str, is_internal: bool) -> None:
-        if is_internal:
-            telemetry.report_internal_query()
-        elif qid:
-            telemetry.report_query_id(qid)
-
     def wrap_execute(wrapped_fn):
         def fn(query: str, **kwargs):
             update_cache_for_query(query)
-            is_internal = kwargs.get("_is_internal", False)
             try:
                 result = wrapped_fn(query, **kwargs)
-                report_query(result
+                telemetry.report_query(result, **kwargs)
             except Exception as e:
-                report_query(e
+                telemetry.report_query(e, **kwargs)
                 raise e
             return result
 
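The local report_query helper is removed; the wrapper now hands the raw result or exception, together with the execute kwargs, straight to telemetry.report_query. A generic sketch of the same wrapping pattern with stand-in dependencies (the real code patches the Snowpark connection's execute functions):

```python
# Sketch of the instrumentation pattern above; only the shape of wrap_execute
# mirrors the diff, the injected dependencies are stand-ins.
def wrap_execute(wrapped_fn, telemetry, update_cache_for_query):
    def fn(query: str, **kwargs):
        update_cache_for_query(query)  # may clear the describe-query cache on DDL
        try:
            result = wrapped_fn(query, **kwargs)
            telemetry.report_query(result, **kwargs)  # success: report the result
        except Exception as e:
            telemetry.report_query(e, **kwargs)  # failure: report the exception
            raise e
        return result

    return fn
```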
snowflake/snowpark_connect/utils/identifiers.py (renamed from attribute_handling.py):

@@ -1,6 +1,53 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
+import re
+
+from pyspark.errors import AnalysisException
+
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    quote_name_without_upper_casing,
+)
+from snowflake.snowpark_connect.config import (
+    auto_uppercase_column_identifiers,
+    auto_uppercase_non_column_identifiers,
+)
+
+QUOTED_SPARK_IDENTIFIER = re.compile(r"^`[^`]*(?:``[^`]*)*`$")
+UNQUOTED_SPARK_IDENTIFIER = re.compile(r"^\w+$")
+
+
+def unquote_spark_identifier_if_quoted(spark_name: str) -> str:
+    if UNQUOTED_SPARK_IDENTIFIER.match(spark_name):
+        return spark_name
+
+    if QUOTED_SPARK_IDENTIFIER.match(spark_name):
+        return spark_name[1:-1].replace("``", "`")
+
+    raise AnalysisException(f"Invalid name: {spark_name}")
+
+
+def spark_to_sf_single_id_with_unquoting(name: str) -> str:
+    """
+    Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
+    Unquotes the spark name if necessary. Will raise an AnalysisException if given name is not valid.
+    """
+    return spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+
+
+def spark_to_sf_single_id(name: str, is_column: bool = False) -> str:
+    """
+    Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
+    Assumes that the given spark name doesn't contain quotes,
+    meaning it's either already unquoted, or didn't need quoting.
+    """
+    name = quote_name_without_upper_casing(name)
+    should_uppercase = (
+        auto_uppercase_column_identifiers()
+        if is_column
+        else auto_uppercase_non_column_identifiers()
+    )
+    return name.upper() if should_uppercase else name
 
 
 def split_fully_qualified_spark_name(qualified_name: str | None) -> list[str]:
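The renamed module centralizes Spark-to-Snowflake identifier handling: backtick-quoted Spark names are unquoted (a doubled backtick collapses to a single backtick), then re-quoted for Snowflake and optionally uppercased according to the auto-uppercase configuration. A quick illustration of the expected behavior; since the uppercasing outcome depends on the configuration flags, the final name is only printed here rather than asserted:

```python
# Expected behaviour of the new identifier helpers from the hunk above.
from snowflake.snowpark_connect.utils.identifiers import (
    spark_to_sf_single_id_with_unquoting,
    unquote_spark_identifier_if_quoted,
)

# Backtick-quoted Spark identifiers are unquoted; `` inside becomes a single `.
assert unquote_spark_identifier_if_quoted("`my``col`") == "my`col"

# Plain \w+ identifiers pass through untouched.
assert unquote_spark_identifier_if_quoted("order_id") == "order_id"

# The Snowflake-side name is double-quoted and, depending on the config flags,
# uppercased (e.g. '"ORDER_ID"' or '"order_id"').
print(spark_to_sf_single_id_with_unquoting("order_id"))
```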
snowflake/snowpark_connect/utils/session.py:

@@ -5,11 +5,10 @@
 import logging
 import os
 from collections.abc import Sequence
-from contextlib import contextmanager
 from typing import Any
 
 from snowflake import snowpark
-from snowflake.snowpark.exceptions import SnowparkClientException
+from snowflake.snowpark.exceptions import SnowparkClientException, SnowparkSQLException
 from snowflake.snowpark.session import _get_active_session
 from snowflake.snowpark_connect.constants import DEFAULT_CONNECTION_NAME
 from snowflake.snowpark_connect.utils.describe_query_cache import (
@@ -20,6 +19,18 @@ from snowflake.snowpark_connect.utils.telemetry import telemetry
 from snowflake.snowpark_connect.utils.udf_cache import init_builtin_udf_cache
 
 
+# Suppress experimental warnings from snowflake.snowpark logger
+def _filter_experimental_warnings(record):
+    """Filter function to suppress experimental warnings."""
+    message = record.getMessage()
+    return not (
+        "is experimental since" in message and "Do not use it in production" in message
+    )
+
+
+logging.getLogger("snowflake.snowpark").addFilter(_filter_experimental_warnings)
+
+
 def _get_current_snowpark_session() -> snowpark.Session | None:
     # TODO: this is a temporary solution to get the current session, it would be better to add a function in snowpark
     try:
@@ -34,33 +45,6 @@ def _get_current_snowpark_session() -> snowpark.Session | None:
         raise
 
 
-@contextmanager
-def suppress_experimental_warnings():
-    """
-    Suppress experimental parameter warnings from snowpark logging.
-
-    This context manager filters out logging messages containing
-    "is experimental since" and "Do not use it in production"
-    from the snowpark logger, while preserving other important warnings.
-    """
-    snowpark_logger = logging.getLogger("snowflake.snowpark")
-
-    def filter_experimental_warnings(record):
-        """Filter function to suppress experimental parameter warnings."""
-        message = record.getMessage()
-        return not (
-            "is experimental since" in message
-            and "Do not use it in production" in message
-        )
-
-    snowpark_logger.addFilter(filter_experimental_warnings)
-
-    try:
-        yield
-    finally:
-        snowpark_logger.removeFilter(filter_experimental_warnings)
-
-
 def configure_snowpark_session(session: snowpark.Session):
     """Configure a snowpark session with required parameters and settings."""
     from snowflake.snowpark_connect.config import global_config
@@ -80,11 +64,10 @@ def configure_snowpark_session(session: snowpark.Session):
     # built-in udf cache
     init_builtin_udf_cache(session)
 
-    # Set experimental parameters
-
-
-
-    session.reduce_describe_query_enabled = True
+    # Set experimental parameters (warnings globally suppressed)
+    session.ast_enabled = False
+    session.eliminate_numeric_sql_value_cast_enabled = False
+    session.reduce_describe_query_enabled = True
 
     session._join_alias_fix = True
     session.connection.arrow_number_to_decimal_setter = True
@@ -101,6 +84,30 @@ def configure_snowpark_session(session: snowpark.Session):
     session.sql(
         f"ALTER SESSION SET {', '.join([f'{k} = {v}' for k, v in session_params.items()])}"
     ).collect()
+
+    # Rolling ahead in preparation of GS release 9.22 (ETA 8/5/2025). Once 9.22 is past rollback risk, merge this
+    # parameter with other in the session_params dictionary above
+    try:
+        session.sql(
+            "ALTER SESSION SET ENABLE_STRUCTURED_TYPES_IN_SNOWPARK_CONNECT_RESPONSE=true"
+        ).collect()
+    except SnowparkSQLException:
+        logger.debug(
+            "ENABLE_STRUCTURED_TYPES_IN_SNOWPARK_CONNECT_RESPONSE is not defined"
+        )
+    try:
+        session.sql(
+            "ALTER SESSION SET ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT=true"
+        ).collect()
+    except SnowparkSQLException:
+        logger.debug("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT is not defined")
+    try:
+        session.sql(
+            "ALTER SESSION SET ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE=true"
+        ).collect()
+    except SnowparkSQLException:
+        logger.debug("ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE is not defined")
+
     # Instrument the snowpark session to use a cache for describe queries.
     instrument_session_for_describe_cache(session)
 
@@ -174,7 +181,3 @@ def set_query_tags(spark_tags: Sequence[str]) -> None:
 
     if spark_tags_str != snowpark_session.query_tag:
         snowpark_session.query_tag = spark_tags_str
-
-
-def get_python_udxf_import_files(session: snowpark.Session) -> str:
-    return ",".join([file for file in [*session._python_files, *session._import_files]])