snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +91 -40
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/config.py +9 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/literal.py +12 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
- snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
- snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
- snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
- snowflake/snowpark_connect/relation/map_join.py +454 -252
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +335 -90
- snowflake/snowpark_connect/relation/read/map_read.py +9 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
- snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +41 -0
- snowflake/snowpark_connect/relation/utils.py +50 -2
- snowflake/snowpark_connect/relation/write/map_write.py +251 -292
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +4 -1
- snowflake/snowpark_connect/utils/udf_helper.py +1 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-import pathlib
 import threading
 import time
 
@@ -51,11 +50,9 @@ def initialize_resources() -> None:
     """Upload Spark jar files required for creating Scala UDFs."""
     stage = session.get_session_stage()
     resource_path = stage + RESOURCE_PATH
-    import snowflake.snowpark_connect
+    import snowpark_connect_deps_1
+    import snowpark_connect_deps_2
 
-    pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes/jars"
-    )
     jar_files = [
         f"spark-sql_2.12-{SPARK_VERSION}.jar",
         f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar",
@@ -64,14 +61,29 @@ def initialize_resources() -> None:
         "json4s-ast_2.12-3.7.0-M11.jar",
     ]
 
-    for jar_file in jar_files:
-        session.file.put(
-            str(pyspark_jars / jar_file),
-            resource_path,
-            auto_compress=False,
-            overwrite=False,
-            source_compression="NONE",
-        )
+    for jar_name in jar_files:
+        # Try to find the JAR in package 1 first, then package 2
+        jar_path = None
+        try:
+            jar_path = snowpark_connect_deps_1.get_jar_path(jar_name)
+        except FileNotFoundError:
+            try:
+                jar_path = snowpark_connect_deps_2.get_jar_path(jar_name)
+            except FileNotFoundError:
+                raise FileNotFoundError(
+                    f"JAR {jar_name} not found in either package"
+                )
+
+        try:
+            session.file.put(
+                str(jar_path),
+                resource_path,
+                auto_compress=False,
+                overwrite=False,
+                source_compression="NONE",
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to upload JAR {jar_name}: {e}")
 
     start_time = time.time()
 
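Note: the rewritten upload loop relies on each snowpark-connect-deps package exposing get_jar_path(name), which raises FileNotFoundError when a JAR is absent, and (per the start_jvm hunk further down) list_jars(). A minimal sketch of what such a helper module could look like, assuming the JARs ship under an includes/jars directory inside each deps package; the actual layout of snowpark_connect_deps_1/2 is not shown in this diff:

# Hypothetical helper module for a snowpark_connect_deps_* package.
# The directory layout is an assumption, not the published layout.
import pathlib

_JAR_DIR = pathlib.Path(__file__).parent / "includes" / "jars"


def get_jar_path(jar_name: str) -> pathlib.Path:
    # Return the on-disk path of a bundled JAR; raise FileNotFoundError so
    # callers can fall back to the other deps package, as the loop above does.
    path = _JAR_DIR / jar_name
    if not path.is_file():
        raise FileNotFoundError(f"{jar_name} is not bundled in this package")
    return path


def list_jars() -> list[str]:
    # Enumerate every bundled JAR path for JVM classpath registration.
    return sorted(str(p) for p in _JAR_DIR.glob("*.jar"))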
@@ -24,12 +24,10 @@
 import atexit
 import logging
 import os
-import pathlib
 import socket
 import tempfile
 import threading
 import urllib.parse
-import zipfile
 from concurrent import futures
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
@@ -48,7 +46,6 @@ from pyspark.errors import PySparkValueError
 from pyspark.sql.connect.client.core import ChannelBuilder
 from pyspark.sql.connect.session import SparkSession
 
-import snowflake.snowpark_connect
 import snowflake.snowpark_connect.proto.control_pb2_grpc as control_grpc
 import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
@@ -1032,28 +1029,16 @@ def start_jvm():
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-
-    pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes/jars"
-    )
-    if "dataframe_processor.zip" in str(pyspark_jars):
-        # importlib.resource doesn't work when local stage package is used in TCM
-        zip_path = pathlib.Path(
-            snowflake.snowpark_connect.__file__
-        ).parent.parent.parent
-        temp_dir = tempfile.gettempdir()
-
-        extract_folder = "snowflake/snowpark_connect/includes/jars/"  # Folder to extract (must end with '/')
+    # Import both JAR dependency packages
+    import snowpark_connect_deps_1
+    import snowpark_connect_deps_2
 
-        with zipfile.ZipFile(zip_path, "r") as zf:
-            for member in zf.namelist():
-                if member.startswith(extract_folder):
-                    zf.extract(member, temp_dir)
-
-        pyspark_jars = pathlib.Path(temp_dir) / extract_folder
-
-    for path in pyspark_jars.glob("**/*.jar"):
-        jpype.addClassPath(path)
+    # Load all the jar files from both packages
+    jar_path_list = (
+        snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars()
+    )
+    for jar_path in jar_path_list:
+        jpype.addClassPath(jar_path)
 
     # TODO: Should remove convertStrings, but it breaks the JDBC code.
     jvm_settings: list[str] = list(
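The rewritten start_jvm() registers every bundled JAR on the JPype classpath before booting the JVM; jpype.addClassPath has no effect once the JVM is already running. A short sketch of that ordering, reusing the deps-package helpers this diff introduces:

# Classpath entries must be registered before startJVM(); adding them
# afterwards does not affect the running JVM.
import jpype

import snowpark_connect_deps_1
import snowpark_connect_deps_2

for jar_path in snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars():
    jpype.addClassPath(jar_path)

jpype.startJVM(convertStrings=True)  # the TODO above notes convertStrings is still needed for JDBC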
@@ -373,6 +373,8 @@ def cast_to_match_snowpark_type(
             return str(content)
         case snowpark.types.DayTimeIntervalType:
             return str(content)
+        case snowpark.types.MapType:
+            return content
         case _:
             exception = SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"
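The new MapType branch returns the map content as-is rather than converting it to a string. A self-contained illustration of the match-based dispatch, simplified to the branches visible in this hunk (the real function handles many more types):

# Illustration only: map values pass through unconverted, unlike the
# interval types, which are stringified.
from snowflake import snowpark


def cast_to_match(content, data_type):
    match type(data_type):
        case snowpark.types.DayTimeIntervalType:
            return str(content)
        case snowpark.types.MapType:
            return content  # dict survives untouched
        case _:
            raise NotImplementedError(f"Unsupported snowpark data type in casting: {data_type}")


assert cast_to_match({"k": 1}, snowpark.types.MapType()) == {"k": 1}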
@@ -49,7 +49,7 @@ class TypedColumn:
         self.qualifiers = qualifiers
 
     def get_qualifiers(self) -> set[ColumnQualifier]:
-        return getattr(self, "qualifiers",
+        return getattr(self, "qualifiers", set())
 
     def set_catalog_database_info(self, catalog_database_info: dict[str, str]) -> None:
         self._catalog_database_info = catalog_database_info
@@ -70,7 +70,7 @@ class TypedColumn:
     def get_multi_col_qualifiers(self, num_columns) -> list[set[ColumnQualifier]]:
         if not hasattr(self, "multi_col_qualifiers"):
 
-            return [
+            return [set() for i in range(num_columns)]
         assert (
             len(self.multi_col_qualifiers) == num_columns
         ), f"Expected {num_columns} multi-column qualifiers, got {len(self.multi_col_qualifiers)}"
@@ -55,7 +55,6 @@ _resolving_lambda_fun = ContextVar[bool]("_resolving_lambdas", default=False)
 _current_lambda_params = ContextVar[list[str]]("_current_lambda_params", default=[])
 
 _is_window_enabled = ContextVar[bool]("_is_window_enabled", default=False)
-_is_in_pivot = ContextVar[bool]("_is_in_pivot", default=False)
 _is_in_udtf_context = ContextVar[bool]("_is_in_udtf_context", default=False)
 _accessing_temp_object = ContextVar[bool]("_accessing_temp_object", default=False)
 
@@ -467,19 +466,6 @@ def is_window_enabled():
     return _is_window_enabled.get()
 
 
-@contextmanager
-def temporary_pivot_expression(value: bool):
-    token = _is_in_pivot.set(value)
-    try:
-        yield
-    finally:
-        _is_in_pivot.reset(token)
-
-
-def is_in_pivot() -> bool:
-    return _is_in_pivot.get()
-
-
 def get_is_in_udtf_context() -> bool:
     """
     Gets the value of _is_in_udtf_context for the current context, defaults to False.
@@ -0,0 +1,163 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from snowflake.snowpark import Column, functions as snowpark_fn
+from snowflake.snowpark._internal.analyzer.expression import (
+    CaseWhen,
+    Expression,
+    FunctionExpression,
+    SnowflakeUDF,
+)
+
+_SF_AGGREGATE_FUNCTIONS = [
+    "any_value",
+    "avg",
+    "corr",
+    "count",
+    "count_if",
+    "covar_pop",
+    "covar_samp",
+    "listagg",
+    "max",
+    "max_by",
+    "median",
+    "min",
+    "min_by",
+    "mode",
+    "percentile_cont",
+    "percentile_disc",
+    "stddev",
+    "stddev_samp",
+    "stddev_pop",
+    "sum",
+    "var_pop",
+    "var_samp",
+    "variance_pop",
+    "variance",
+    "variance_samp",
+    "bitand_agg",
+    "bitor_agg",
+    "bitxor_agg",
+    "booland_agg",
+    "boolor_agg",
+    "boolxor_agg",
+    "hash_agg",
+    "array_agg",
+    "object_agg",
+    "regr_avgx",
+    "regr_avgy",
+    "regr_count",
+    "regr_intercept",
+    "regr_r2",
+    "regr_slope",
+    "regr_sxx",
+    "regr_sxy",
+    "regr_syy",
+    "kurtosis",
+    "skew",
+    "array_union_agg",
+    "array_unique_agg",
+    "bitmap_bit_position",
+    "bitmap_bucket_number",
+    "bitmap_count",
+    "bitmap_construct_agg",
+    "bitmap_or_agg",
+    "approx_count_distinct",
+    "datasketches_hll",
+    "datasketches_hll_accumulate",
+    "datasketches_hll_combine",
+    "datasketches_hll_estimate",
+    "hll",
+    "hll_accumulate",
+    "hll_combine",
+    "hll_estimate",
+    "hll_export",
+    "hll_import",
+    "approximate_jaccard_index",
+    "approximate_similarity",
+    "minhash",
+    "minhash_combine",
+    "approx_top_k",
+    "approx_top_k_accumulate",
+    "approx_top_k_combine",
+    "approx_top_k_estimate",
+    "approx_percentile",
+    "approx_percentile_accumulate",
+    "approx_percentile_combine",
+    "approx_percentile_estimate",
+    "grouping",
+    "grouping_id",
+    "ai_agg",
+    "ai_summarize_agg",
+]
+
+
+def _is_agg_function_expression(expression: Expression) -> bool:
+    if (
+        isinstance(expression, FunctionExpression)
+        and expression.pretty_name.lower() in _SF_AGGREGATE_FUNCTIONS
+    ):
+        return True
+
+    # For PySpark aggregate functions that were mapped using a UDAF, e.g. try_sum
+    if isinstance(expression, SnowflakeUDF) and expression.is_aggregate_function:
+        return True
+
+    return False
+
+
+def _get_child_expressions(expression: Expression) -> list[Expression]:
+    if isinstance(expression, CaseWhen):
+        return expression._child_expressions
+
+    return expression.children or []
+
+
+def inject_condition_to_all_agg_functions(
+    expression: Expression, condition: Column
+) -> None:
+    """
+    Recursively traverses an expression tree and wraps all aggregate function arguments with a CASE WHEN condition.
+
+    Args:
+        expression: The Snowpark expression tree to traverse and modify.
+        condition: The Column condition to inject into aggregate function arguments.
+    """
+
+    any_agg_function_found = _inject_condition_to_all_agg_functions(
+        expression, condition
+    )
+
+    if not any_agg_function_found:
+        raise ValueError(f"No aggregate functions found in: {expression.sql}")
+
+
+def _inject_condition_to_all_agg_functions(
+    expression: Expression, condition: Column
+) -> bool:
+    any_agg_function_found = False
+
+    if _is_agg_function_expression(expression):
+        new_children = []
+        for child in _get_child_expressions(expression):
+            case_when = snowpark_fn.when(condition, Column(child))
+
+            new_children.append(case_when._expr1)
+
+        # Swap children
+        expression.children = new_children
+        if len(new_children) > 0:
+            expression.child = new_children[0]
+
+        return True
+
+    for child in _get_child_expressions(expression):
+        is_agg_function_in_child = _inject_condition_to_all_agg_functions(
+            child, condition
+        )
+
+        if is_agg_function_in_child:
+            any_agg_function_found = True
+
+    return any_agg_function_found
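The new helper rewrites each aggregate's arguments in place, so SUM(AMOUNT) becomes SUM(CASE WHEN cond THEN AMOUNT END), the usual way to emulate Spark's FILTER (WHERE ...) aggregate clause. A hedged usage sketch; the column names are illustrative, and Column._expression is a Snowpark internal:

# Illustrative use of the new helper against a Snowpark Column expression.
from snowflake.snowpark import functions as snowpark_fn
from snowflake.snowpark_connect.utils.expression_transformer import (
    inject_condition_to_all_agg_functions,
)

agg_col = snowpark_fn.sum(snowpark_fn.col("AMOUNT"))
condition = snowpark_fn.col("REGION") == snowpark_fn.lit("EMEA")

# Mutates the tree: SUM(AMOUNT) turns into SUM(CASE WHEN REGION = 'EMEA'
# THEN AMOUNT END); non-matching rows contribute NULL, which aggregates skip.
inject_condition_to_all_agg_functions(agg_col._expression, condition)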
@@ -0,0 +1,21 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import threading
+from collections import defaultdict
+
+from snowflake.snowpark_connect.utils.context import get_session_id
+
+# per session number sequences to generate unique snowpark columns
+_session_sequences = defaultdict(int)
+
+_lock = threading.Lock()
+
+
+def next_unique_num():
+    session_id = get_session_id()
+    with _lock:
+        next_num = _session_sequences[session_id]
+        _session_sequences[session_id] = next_num + 1
+        return next_num
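The new sequence module keeps one counter per Spark Connect session, so generated Snowpark column names are unique within a session and independent across sessions. A standalone illustration of the same pattern, with the session id passed explicitly instead of read from context:

# Standalone illustration of the per-session counter used above.
import threading
from collections import defaultdict

_session_sequences: defaultdict[str, int] = defaultdict(int)
_lock = threading.Lock()


def next_unique_num(session_id: str) -> int:
    # Each session advances its own counter; the lock keeps this safe when
    # several worker threads build plans for the same session concurrently.
    with _lock:
        num = _session_sequences[session_id]
        _session_sequences[session_id] = num + 1
        return num


assert next_unique_num("a") == 0
assert next_unique_num("a") == 1
assert next_unique_num("b") == 0  # independent per session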
@@ -94,6 +94,10 @@ def configure_snowpark_session(session: snowpark.Session):
     session.connection.arrow_number_to_decimal_setter = True
     session.custom_package_usage_config["enabled"] = True
 
+    # Scoped temp objects may not be accessible in stored procedure and cause "object does not exist" error. So disable
+    # _use_scoped_temp_objects here and use temp table instead.
+    session._use_scoped_temp_objects = False
+
     # Configure CTE optimization based on session configuration
     cte_optimization_enabled = get_cte_optimization_enabled()
     session.cte_optimization_enabled = cte_optimization_enabled
@@ -128,7 +132,6 @@ def configure_snowpark_session(session: snowpark.Session):
         "TIMEZONE": f"'{global_config.spark_sql_session_timeZone}'",
         "QUOTED_IDENTIFIERS_IGNORE_CASE": "false",
         "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": "true",
-        "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": "false",  # this is required for creating udfs from sproc
         "ENABLE_STRUCTURED_TYPES_IN_SNOWPARK_CONNECT_RESPONSE": "true",
         "QUERY_TAG": f"'{query_tag}'",
     }
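Together with the hunk above, this replaces the PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS session parameter with the client-side Snowpark flag, since scoped temp objects can be invisible inside stored procedures. A minimal sketch, assuming a connections.toml profile named "default":

# Sketch only: the connection profile name is a placeholder.
from snowflake import snowpark

session = snowpark.Session.builder.configs({"connection_name": "default"}).create()

# 1.0.0 sets the private client flag instead of the session parameter, so
# temp objects created while registering UDFs from a sproc remain visible.
session._use_scoped_temp_objects = False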
@@ -186,6 +186,7 @@ def parse_return_type(return_type_json_str) -> Optional[DataType]:
 
 
 def create(session, called_from, return_type_json_str, input_types_json_str, input_column_names_json_str, udf_name, replace, udf_packages, udf_imports, b64_str, original_return_type):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -153,6 +153,7 @@ def parse_types(types_json_str) -> Optional[list[DataType]]:
     return json.loads(types_json_str)
 
 def create(session, b64_str, expected_types_json_str, output_schema_json_str, packages, imports, is_arrow_enabled, is_spark_compatible_udtf_mode_enabled, called_from):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -257,6 +258,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 {inline_udtf_utils_py_code}
 
 def create(session, b64_str, spark_column_names_json_str, input_schema_json_str, return_schema_json_str):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -330,6 +332,7 @@ from snowflake.snowpark.types import _parse_datatype_json_value
 from pyspark.serializers import CloudPickleSerializer
 
 def create(session, func_info_json):
+    session._use_scoped_temp_objects = False
     import snowflake.snowpark.context as context
     context._use_structured_type_semantics = True
     context._is_snowpark_connect_compatible_mode = True
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-connect
-Version: 0.32.0
+Version: 1.0.0
 Summary: Snowpark Connect for Spark
 Author: Snowflake, Inc
 License: Apache License, Version 2.0
@@ -9,6 +9,8 @@ Description-Content-Type: text/markdown
 License-File: LICENSE.txt
 License-File: LICENSE-binary
 License-File: NOTICE-binary
+Requires-Dist: snowpark-connect-deps-1==3.56.2
+Requires-Dist: snowpark-connect-deps-2==3.56.2
 Requires-Dist: certifi>=2025.1.31
 Requires-Dist: cloudpickle
 Requires-Dist: fsspec[http]
@@ -16,7 +18,7 @@ Requires-Dist: jpype1
 Requires-Dist: protobuf<6.32.0,>=4.25.3
 Requires-Dist: s3fs>=2025.3.0
 Requires-Dist: snowflake.core<2,>=1.0.5
-Requires-Dist: snowflake-snowpark-python[pandas]<1.
+Requires-Dist: snowflake-snowpark-python[pandas]<1.43.0,==1.42.0
 Requires-Dist: snowflake-connector-python<4.0.0,>=3.18.0
 Requires-Dist: sqlglot>=26.3.8
 Requires-Dist: jaydebeapi