snowpark-connect 1.6.0-py3-none-any.whl → 1.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/client/server.py +37 -0
- snowflake/snowpark_connect/config.py +72 -3
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/map_cast.py +108 -17
- snowflake/snowpark_connect/expression/map_udf.py +1 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
- snowflake/snowpark_connect/resources_initializer.py +90 -29
- snowflake/snowpark_connect/server.py +6 -41
- snowflake/snowpark_connect/server_common/__init__.py +4 -1
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/utils/context.py +8 -0
- snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
- snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
- snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
- snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
- snowflake/snowpark_connect/utils/telemetry.py +33 -22
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/type_support.py (new file):

```diff
@@ -0,0 +1,130 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import threading
+
+from snowflake import snowpark
+from snowflake.snowpark.types import (
+    ArrayType,
+    ByteType,
+    DataType,
+    DecimalType,
+    IntegerType,
+    LongType,
+    MapType,
+    ShortType,
+    StructField,
+    StructType,
+    _IntegralType,
+)
+
+_integral_types_conversion_enabled: bool = False
+_client_mode_lock = threading.Lock()
+
+
+def set_integral_types_conversion(enabled: bool) -> None:
+    global _integral_types_conversion_enabled
+
+    with _client_mode_lock:
+        if _integral_types_conversion_enabled == enabled:
+            return
+
+        _integral_types_conversion_enabled = enabled
+
+        if enabled:
+            snowpark.context._integral_type_default_precision = {
+                LongType: 19,
+                IntegerType: 10,
+                ShortType: 5,
+                ByteType: 3,
+            }
+        else:
+            snowpark.context._integral_type_default_precision = {}
+
+
+def set_integral_types_for_client_default(is_python_client: bool) -> None:
+    """
+    Set integral types based on client type when config is 'client_default'.
+    """
+    from snowflake.snowpark_connect.config import global_config
+
+    config_key = "snowpark.connect.integralTypesEmulation"
+    if global_config.get(config_key) != "client_default":
+        return
+
+    # if client mode matches, no action needed (no lock overhead)
+    if _integral_types_conversion_enabled == (not is_python_client):
+        return
+
+    set_integral_types_conversion(not is_python_client)
+
+
+def emulate_integral_types(t: DataType) -> DataType:
+    """
+    Map LongType based on precision attribute to appropriate integral types.
+
+    Mappings:
+    - _IntegralType with precision=19 -> LongType
+    - _IntegralType with precision=10 -> IntegerType
+    - _IntegralType with precision=5 -> ShortType
+    - _IntegralType with precision=3 -> ByteType
+    - _IntegralType with other precision -> DecimalType(precision, 0)
+
+    This conversion is controlled by the 'snowpark.connect.integralTypesEmulation' config.
+    When disabled, the function returns the input type unchanged.
+
+    Args:
+        t: The DataType to transform
+
+    Returns:
+        The transformed DataType with integral type conversions applied based on precision.
+    """
+    global _integral_types_conversion_enabled
+
+    with _client_mode_lock:
+        enabled = _integral_types_conversion_enabled
+    if not enabled:
+        return t
+    if isinstance(t, _IntegralType):
+        precision = getattr(t, "_precision", None)
+
+        if precision is None:
+            return t
+        elif precision == 19:
+            return LongType()
+        elif precision == 10:
+            return IntegerType()
+        elif precision == 5:
+            return ShortType()
+        elif precision == 3:
+            return ByteType()
+        else:
+            return DecimalType(precision, 0)
+
+    elif isinstance(t, StructType):
+        new_fields = [
+            StructField(
+                field.name,
+                emulate_integral_types(field.datatype),
+                field.nullable,
+                _is_column=field._is_column,
+            )
+            for field in t.fields
+        ]
+        return StructType(new_fields)
+
+    elif isinstance(t, ArrayType):
+        return ArrayType(
+            emulate_integral_types(t.element_type),
+            t.contains_null,
+        )
+
+    elif isinstance(t, MapType):
+        return MapType(
+            emulate_integral_types(t.key_type),
+            emulate_integral_types(t.value_type),
+            t.value_contains_null,
+        )
+
+    return t
```
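The heart of the new module is the precision-based dispatch in `emulate_integral_types`. A minimal standalone sketch of that mapping, using type names as plain strings so it runs without Snowpark installed (the real function returns Snowpark `DataType` instances and recurses into struct, array, and map types):

```python
# Toy stand-in for the precision -> integral type dispatch above; the real
# code returns Snowpark DataType instances, this returns names for clarity.
PRECISION_TO_TYPE = {19: "LongType", 10: "IntegerType", 5: "ShortType", 3: "ByteType"}

def emulate(precision):
    if precision is None:
        return "unchanged"  # no _precision attribute: leave the type alone
    # any precision outside the four integral widths falls back to DECIMAL(p, 0)
    return PRECISION_TO_TYPE.get(precision, f"DecimalType({precision}, 0)")

assert emulate(19) == "LongType"
assert emulate(5) == "ShortType"
assert emulate(38) == "DecimalType(38, 0)"
```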
snowflake/snowpark_connect/utils/context.py:

```diff
@@ -12,6 +12,9 @@ from typing import Iterator, Mapping, Optional
 import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.type_support import (
+    set_integral_types_for_client_default,
+)
 from snowflake.snowpark_connect.typed_column import TypedColumn
 
 # TODO: remove session id from context when we host SAS in Snowflake server
@@ -267,6 +270,11 @@ def set_spark_version(client_type: str) -> None:
     version = match.group("spark_version") if match else ""
     _spark_version.set(version)
 
+    # enable integral types (only if config is "client_default")
+
+    is_python_client = "_SPARK_CONNECT_PYTHON" in client_type
+    set_integral_types_for_client_default(is_python_client)
+
 
 def get_is_aggregate_function() -> tuple[str, bool]:
     """
```
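With the config left at `client_default`, emulation is derived from the client type string: PySpark clients keep it off, JVM clients get it enabled. A small sketch of that rule; only the `_SPARK_CONNECT_PYTHON` marker is confirmed by the diff, and the Scala marker below is a hypothetical illustration:

```python
# Sketch of the client_default resolution. "_SPARK_CONNECT_SCALA" is a
# hypothetical example string, not a marker taken from the diff.
def resolve_emulation(client_type: str) -> bool:
    is_python_client = "_SPARK_CONNECT_PYTHON" in client_type
    return not is_python_client  # JVM clients -> emulation on

assert resolve_emulation("spark/3.5.6 _SPARK_CONNECT_PYTHON") is False
assert resolve_emulation("spark/3.5.6 _SPARK_CONNECT_SCALA") is True
```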
snowflake/snowpark_connect/utils/java_stored_procedure.py:

```diff
@@ -7,11 +7,22 @@ from pyspark.errors import AnalysisException
 import snowflake.snowpark.types as snowpark_type
 from snowflake.snowpark import Session
 from snowflake.snowpark._internal.type_utils import type_string_to_type_object
+from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.config import (
+    get_scala_version,
+    is_java_udf_creator_initialized,
+    set_java_udf_creator_initialized_state,
+)
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.resources_initializer import (
     RESOURCE_PATH,
-    SPARK_COMMON_UTILS_JAR,
-    SPARK_CONNECT_CLIENT_JAR,
-    SPARK_SQL_JAR,
+    SPARK_COMMON_UTILS_JAR_212,
+    SPARK_COMMON_UTILS_JAR_213,
+    SPARK_CONNECT_CLIENT_JAR_212,
+    SPARK_CONNECT_CLIENT_JAR_213,
+    SPARK_SQL_JAR_212,
+    SPARK_SQL_JAR_213,
+    ensure_scala_udf_jars_uploaded,
 )
 from snowflake.snowpark_connect.utils.upload_java_jar import upload_java_udf_jar
@@ -22,7 +33,7 @@ CREATE OR REPLACE TEMPORARY PROCEDURE __SC_JAVA_SP_CREATE_JAVA_UDF(udf_name VARC
 RETURNS VARCHAR
 LANGUAGE JAVA
 RUNTIME_VERSION = 17
-PACKAGES = ('com.snowflake:
+PACKAGES = ('com.snowflake:snowpark___scala_version__:latest')
 __snowflake_udf_imports__
 HANDLER = 'com.snowflake.snowpark_connect.procedures.JavaUDFCreator.process'
 EXECUTE AS CALLER
@@ -30,19 +41,6 @@ EXECUTE AS CALLER
 """
 
 
-_is_initialized = False
-
-
-def is_initialized() -> bool:
-    global _is_initialized
-    return _is_initialized
-
-
-def set_java_udf_creator_initialized_state(value: bool) -> None:
-    global _is_initialized
-    _is_initialized = value
-
-
 class JavaUdf:
     """
     Reference class for Java UDFs, providing similar properties like Python UserDefinedFunction.
@@ -70,12 +68,33 @@ class JavaUdf:
         self._return_type = return_type
 
 
+def _scala_static_imports_for_sproc(stage_resource_path: str) -> set[str]:
+    scala_version = get_scala_version()
+    if scala_version == "2.12":
+        return {
+            f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_212}",
+            f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_212}",
+            f"{stage_resource_path}/{SPARK_SQL_JAR_212}",
+        }
+
+    if scala_version == "2.13":
+        return {
+            f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_213}",
+            f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_213}",
+            f"{stage_resource_path}/{SPARK_SQL_JAR_213}",
+        }
+
+    # invalid Scala version
+    exception = ValueError(
+        f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+    )
+    attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+    raise exception
+
+
 def get_quoted_imports(session: Session) -> str:
     stage_resource_path = session.get_session_stage() + RESOURCE_PATH
-    spark_imports = {
-        f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR}",
-        f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR}",
-        f"{stage_resource_path}/{SPARK_SQL_JAR}",
+    spark_imports = _scala_static_imports_for_sproc(stage_resource_path) | {
         f"{stage_resource_path}/java_udfs-1.0-SNAPSHOT.jar",
     }
 
```
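The new `_scala_static_imports_for_sproc` helper is a two-way dispatch on the configured Scala version. A self-contained sketch of the same pattern, with illustrative jar names standing in for the `SPARK_*_JAR_212`/`_213` constants:

```python
# Illustrative stand-ins for the SPARK_*_JAR_212/213 stage-file constants.
JARS_BY_SCALA = {
    "2.12": ("spark-connect-client-jvm_2.12.jar", "spark-common-utils_2.12.jar", "spark-sql_2.12.jar"),
    "2.13": ("spark-connect-client-jvm_2.13.jar", "spark-common-utils_2.13.jar", "spark-sql_2.13.jar"),
}

def static_imports(stage_path: str, scala_version: str) -> set[str]:
    jars = JARS_BY_SCALA.get(scala_version)
    if jars is None:
        # mirrors the ValueError (with custom error code) raised by the real helper
        raise ValueError(f"Unsupported Scala version: {scala_version}. Supported: 2.12 and 2.13")
    return {f"{stage_path}/{jar}" for jar in jars}

print(static_imports("@session_stage/resources", "2.13"))
```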
snowflake/snowpark_connect/utils/java_stored_procedure.py (continued):

```diff
@@ -83,14 +102,21 @@ def get_quoted_imports(session: Session) -> str:
     """Helper function to wrap strings in single quotes for SQL."""
     return "'" + s + "'"
 
-
+    from snowflake.snowpark_connect.config import global_config
 
+    config_imports = global_config.get("snowpark.connect.udf.java.imports", "")
+    config_imports = (
+        {x.strip() for x in config_imports.strip("[] ").split(",") if x.strip()}
+        if config_imports
+        else set()
+    )
 
-
-
-        ensure_scala_udf_jars_uploaded,
+    return ", ".join(
+        quote_single(x) for x in session._artifact_jars | spark_imports | config_imports
     )
 
+
+def create_snowflake_imports(session: Session) -> str:
     # Make sure that the resource initializer thread is completed before creating Java UDFs since we depend on the jars
     # uploaded by it.
     ensure_scala_udf_jars_uploaded()
```
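The new `snowpark.connect.udf.java.imports` setting accepts a bracketed, comma-separated list of extra jar paths. A runnable sketch of exactly the parsing expression used above, with a hypothetical config value:

```python
# Mirrors the parsing in get_quoted_imports: strip the brackets, split on
# commas, and drop empty fragments.
raw = "[@stage/extra_a.jar, @stage/extra_b.jar]"  # hypothetical config value
parsed = {x.strip() for x in raw.strip("[] ").split(",") if x.strip()}
assert parsed == {"@stage/extra_a.jar", "@stage/extra_b.jar"}

# An empty or missing value yields an empty set, so it unions safely with the
# session's artifact jars and the static Spark jars.
assert not {x.strip() for x in "".strip("[] ").split(",") if x.strip()}
```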
```diff
@@ -99,12 +125,12 @@ def create_snowflake_imports(session: Session) -> str:
 
 
 def create_java_udf(session: Session, function_name: str, java_class: str):
-    if not is_initialized():
+    if not is_java_udf_creator_initialized():
         upload_java_udf_jar(session)
         session.sql(
             SP_TEMPLATE.replace(
                 "__snowflake_udf_imports__", create_snowflake_imports(session)
-            )
+            ).replace("__scala_version__", get_scala_version())
         ).collect()
         set_java_udf_creator_initialized_state(True)
     name = CREATE_JAVA_UDF_PREFIX + function_name
```
snowflake/snowpark_connect/utils/java_udaf_utils.py:

```diff
@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
     ReturnType,
     Signature,
     build_jvm_udxf_imports,
-    cast_java_map_args_from_given_type,
     map_type_to_java_type,
 )
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
@@ -41,19 +40,20 @@ import com.snowflake.snowpark_java.types.*;
 
 public class JavaUDAF {
     private final static String OPERATION_FILE = "__operation_file__";
-    private static scala.Function2<
+    private static scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__> operation = null;
+    private static UdfPacket udfPacket = null;
 
     private static void loadOperation() throws IOException, ClassNotFoundException {
         if (operation != null) {
             return; // Already loaded
         }
 
-
-        operation = (scala.Function2<
+        udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
+        operation = (scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__>) udfPacket.function();
     }
 
     public static class State implements Serializable {
-        public
+        public __reduce_type__ value = null;
         public boolean initialized = false;
     }
 
@@ -69,10 +69,10 @@ public class JavaUDAF {
         }
 
         if (!state.initialized) {
-            state.value =
+            state.value = __mapped_value__;
             state.initialized = true;
         } else {
-            state.value = operation.apply(state.value,
+            state.value = operation.apply(state.value, __mapped_value__);
         }
         return state;
     }
@@ -115,7 +115,6 @@ class JavaUDAFDef:
     name: str
     signature: Signature
    java_signature: Signature
-    java_invocation_args: list[str]
     imports: list[str]
     null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT
 
@@ -131,17 +130,31 @@ class JavaUDAFDef:
         Returns:
            String containing the complete Java code for the UDAF body
        """
-        returns_variant = self.signature.returns.data_type == "
+        returns_variant = self.signature.returns.data_type.lower() == "variant"
        return_type = (
            "Variant" if returns_variant else self.java_signature.params[0].data_type
        )
        response_wrapper = (
-            "
+            "com.snowflake.sas.scala.Utils$.MODULE$.toVariant(state.value, udfPacket)"
+            if returns_variant
+            else "state.value"
+        )
+
+        is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"
+        reduce_type = (
+            "Object" if is_variant_input else self.java_signature.params[0].data_type
        )
        return (
            UDAF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
            .replace("__accumulator_type__", self.java_signature.params[0].data_type)
            .replace("__value_type__", self.java_signature.params[1].data_type)
+            .replace(
+                "__mapped_value__",
+                "com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0)"
+                if is_variant_input
+                else "input",
+            )
+            .replace("__reduce_type__", reduce_type)
            .replace("__return_type__", return_type)
            .replace("__response_wrapper__", response_wrapper)
        )
```
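The Java UDAF body is specialized by chaining `str.replace` over placeholder tokens such as `__reduce_type__` and `__mapped_value__`. A toy sketch of that mechanism (the template lines are illustrative, not the real `UDAF_TEMPLATE`):

```python
# Toy template specialization, mirroring the placeholder substitution above.
template = (
    "private static scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__> operation;\n"
    "state.value = operation.apply(state.value, __mapped_value__);"
)
is_variant_input = True
body = (
    template
    # variant inputs are unpacked from VARIANT before reaching the reducer
    .replace("__mapped_value__", "fromVariant(udfPacket, input, 0)" if is_variant_input else "input")
    .replace("__reduce_type__", "Object" if is_variant_input else "long")
)
print(body)
```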
snowflake/snowpark_connect/utils/java_udaf_utils.py (continued):

```diff
@@ -231,12 +244,11 @@ def create_java_udaf_for_reduce_scala_function(
         A JavaUdaf object representing the Java UDAF.
     """
     from snowflake.snowpark_connect.resources_initializer import (
-        wait_for_resource_initialization,
+        ensure_scala_udf_jars_uploaded,
     )
 
-    # Make sure
-
-    wait_for_resource_initialization()
+    # Make sure Scala UDF jars are uploaded before creating Java UDAFs since we depend on them.
+    ensure_scala_udf_jars_uploaded()
 
     from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 
@@ -252,23 +264,26 @@ def create_java_udaf_for_reduce_scala_function(
 
     java_input_params: list[Param] = []
     sql_input_params: list[Param] = []
-    java_invocation_args: list[str] = []  # arguments passed into the udf function
     if input_types:  # input_types can be None when no arguments are provided
         for i, input_type in enumerate(input_types):
             param_name = "arg" + str(i)
+            if isinstance(
+                input_type,
+                (
+                    snowpark_type.ArrayType,
+                    snowpark_type.MapType,
+                    snowpark_type.VariantType,
+                ),
+            ):
+                java_type = "Variant"
+                snowflake_type = "Variant"
+            else:
+                java_type = map_type_to_java_type(input_type)
+                snowflake_type = map_type_to_snowflake_type(input_type)
             # Create the Java arguments and input types string: "arg0: Type0, arg1: Type1, ...".
-            java_input_params.append(
-                Param(param_name, map_type_to_java_type(input_type))
-            )
+            java_input_params.append(Param(param_name, java_type))
             # Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
-            sql_input_params.append(
-                Param(param_name, map_type_to_snowflake_type(input_type))
-            )
-            # In the case of Map input types, we need to cast the argument to the correct type in Java.
-            # Snowflake SQL Java can only handle MAP[VARCHAR, VARCHAR] as input types.
-            java_invocation_args.append(
-                cast_java_map_args_from_given_type(param_name, input_type)
-            )
+            sql_input_params.append(Param(param_name, snowflake_type))
 
     java_return_type = map_type_to_java_type(pciudf._original_return_type)
     # If the SQL return type is a MAP or STRUCT, change this to VARIANT because of issues with Java UDAFs.
```
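Semi-structured inputs are now funneled through VARIANT on both the Java and SQL side, replacing the removed `cast_java_map_args_from_given_type` path. A small sketch of the routing decision, with type names as string stand-ins for the `snowpark_type` classes:

```python
SEMI_STRUCTURED = {"ArrayType", "MapType", "VariantType"}

def route_input(type_name: str, mapped_java: str, mapped_sql: str) -> tuple[str, str]:
    # Semi-structured types collapse to VARIANT on both sides; everything
    # else keeps the mapped Java/SQL type pair.
    if type_name in SEMI_STRUCTURED:
        return "Variant", "Variant"
    return mapped_java, mapped_sql

assert route_input("MapType", "Map<String, String>", "MAP(VARCHAR, VARCHAR)") == ("Variant", "Variant")
assert route_input("LongType", "long", "BIGINT") == ("long", "BIGINT")
```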
```diff
@@ -282,7 +297,11 @@ def create_java_udaf_for_reduce_scala_function(
     )
     sql_return_type = (
         "VARIANT"
-        if (
+        if (
+            sql_return_type.startswith("MAP")
+            or sql_return_type.startswith("OBJECT")
+            or sql_return_type.startswith("ARRAY")
+        )
         else sql_return_type
     )
 
@@ -295,7 +314,6 @@ def create_java_udaf_for_reduce_scala_function(
         java_signature=Signature(
             params=java_input_params, returns=ReturnType(java_return_type)
         ),
-        java_invocation_args=java_invocation_args,
     )
     create_udf_sql = udf_def.to_create_function_sql()
     logger.info(f"Creating Java UDAF: {create_udf_sql}")
```
snowflake/snowpark_connect/utils/java_udtf_utils.py:

```diff
@@ -95,7 +95,7 @@ public class JavaUdtfHandler {
         java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
             public boolean hasNext() { return scalaResult.hasNext(); }
             public Variant next() {
-                return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next());
+                return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next(), udfPacket);
             }
         };
 
```
snowflake/snowpark_connect/utils/jvm_udf_utils.py:

```diff
@@ -9,16 +9,23 @@ from typing import List, Union
 import snowflake.snowpark.types as snowpark_type
 import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
 from snowflake import snowpark
+from snowflake.snowpark_connect.config import get_scala_version
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
 from snowflake.snowpark_connect.resources_initializer import (
-
+    JSON_4S_JAR_212,
+    JSON_4S_JAR_213,
     RESOURCE_PATH,
-
-
-
-
-
+    SAS_SCALA_UDF_JAR_212,
+    SAS_SCALA_UDF_JAR_213,
+    SCALA_REFLECT_JAR_212,
+    SCALA_REFLECT_JAR_213,
+    SPARK_COMMON_UTILS_JAR_212,
+    SPARK_COMMON_UTILS_JAR_213,
+    SPARK_CONNECT_CLIENT_JAR_212,
+    SPARK_CONNECT_CLIENT_JAR_213,
+    SPARK_SQL_JAR_212,
+    SPARK_SQL_JAR_213,
 )
 
 
@@ -108,15 +115,41 @@ def build_jvm_udxf_imports(
     )
 
     # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-    return
-        closure_binary_file
-
-
-
-
-
-
-
+    return (
+        [closure_binary_file]
+        + _scala_static_imports_for_udf(stage_resource_path)
+        + list(session._artifact_jars)
+    )
+
+
+def _scala_static_imports_for_udf(stage_resource_path: str) -> list[str]:
+    scala_version = get_scala_version()
+    if scala_version == "2.12":
+        return [
+            f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_212}",
+            f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_212}",
+            f"{stage_resource_path}/{SPARK_SQL_JAR_212}",
+            f"{stage_resource_path}/{JSON_4S_JAR_212}",
+            f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_212}",
+            f"{stage_resource_path}/{SCALA_REFLECT_JAR_212}",  # Required for deserializing Scala lambdas
+        ]
+
+    if scala_version == "2.13":
+        return [
+            f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_213}",
+            f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_213}",
+            f"{stage_resource_path}/{SPARK_SQL_JAR_213}",
+            f"{stage_resource_path}/{JSON_4S_JAR_213}",
+            f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_213}",
+            f"{stage_resource_path}/{SCALA_REFLECT_JAR_213}",  # Required for deserializing Scala lambdas
+        ]
+
+    # invalid Scala version
+    exception = ValueError(
+        f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+    )
+    attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+    raise exception
 
 
 def map_type_to_java_type(
```
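`build_jvm_udxf_imports` now assembles the IMPORTS list from three sources: the serialized closure file, the version-keyed static jars, and the session's artifact jars. A sketch of that assembly with hypothetical stage paths:

```python
# Assembly of the IMPORTS list, mirroring the new return statement in
# build_jvm_udxf_imports. All values below are hypothetical placeholders.
closure_binary_file = "@session_stage/resources/closure.bin"
static_jars = ["@session_stage/resources/spark-sql_2.13-3.5.6.jar"]  # from _scala_static_imports_for_udf
artifact_jars = {"@session_stage/user_code.jar"}                     # session._artifact_jars

imports = [closure_binary_file] + static_jars + list(artifact_jars)
print(imports)
```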
|