snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. snowflake/snowpark_connect/client/server.py +37 -0
  2. snowflake/snowpark_connect/config.py +72 -3
  3. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  4. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  5. snowflake/snowpark_connect/expression/map_cast.py +108 -17
  6. snowflake/snowpark_connect/expression/map_udf.py +1 -0
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
  8. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  11. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  12. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  13. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
  14. snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
  15. snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
  16. snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
  17. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
  18. snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
  20. snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
  21. snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
  22. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
  23. snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
  24. snowflake/snowpark_connect/resources_initializer.py +90 -29
  25. snowflake/snowpark_connect/server.py +6 -41
  26. snowflake/snowpark_connect/server_common/__init__.py +4 -1
  27. snowflake/snowpark_connect/type_support.py +130 -0
  28. snowflake/snowpark_connect/utils/context.py +8 -0
  29. snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
  30. snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
  31. snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
  32. snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
  33. snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
  34. snowflake/snowpark_connect/utils/telemetry.py +33 -22
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  36. snowflake/snowpark_connect/version.py +1 -1
  37. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
  38. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
  39. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
  40. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/type_support.py
@@ -0,0 +1,130 @@
+ #
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+ #
+
+ import threading
+
+ from snowflake import snowpark
+ from snowflake.snowpark.types import (
+     ArrayType,
+     ByteType,
+     DataType,
+     DecimalType,
+     IntegerType,
+     LongType,
+     MapType,
+     ShortType,
+     StructField,
+     StructType,
+     _IntegralType,
+ )
+
+ _integral_types_conversion_enabled: bool = False
+ _client_mode_lock = threading.Lock()
+
+
+ def set_integral_types_conversion(enabled: bool) -> None:
+     global _integral_types_conversion_enabled
+
+     with _client_mode_lock:
+         if _integral_types_conversion_enabled == enabled:
+             return
+
+         _integral_types_conversion_enabled = enabled
+
+         if enabled:
+             snowpark.context._integral_type_default_precision = {
+                 LongType: 19,
+                 IntegerType: 10,
+                 ShortType: 5,
+                 ByteType: 3,
+             }
+         else:
+             snowpark.context._integral_type_default_precision = {}
+
+
+ def set_integral_types_for_client_default(is_python_client: bool) -> None:
+     """
+     Set integral types based on client type when config is 'client_default'.
+     """
+     from snowflake.snowpark_connect.config import global_config
+
+     config_key = "snowpark.connect.integralTypesEmulation"
+     if global_config.get(config_key) != "client_default":
+         return
+
+     # if client mode matches, no action needed (no lock overhead)
+     if _integral_types_conversion_enabled == (not is_python_client):
+         return
+
+     set_integral_types_conversion(not is_python_client)
+
+
+ def emulate_integral_types(t: DataType) -> DataType:
+     """
+     Map LongType based on precision attribute to appropriate integral types.
+
+     Mappings:
+     - _IntegralType with precision=19 -> LongType
+     - _IntegralType with precision=10 -> IntegerType
+     - _IntegralType with precision=5 -> ShortType
+     - _IntegralType with precision=3 -> ByteType
+     - _IntegralType with other precision -> DecimalType(precision, 0)
+
+     This conversion is controlled by the 'snowpark.connect.integralTypesEmulation' config.
+     When disabled, the function returns the input type unchanged.
+
+     Args:
+         t: The DataType to transform
+
+     Returns:
+         The transformed DataType with integral type conversions applied based on precision.
+     """
+     global _integral_types_conversion_enabled
+
+     with _client_mode_lock:
+         enabled = _integral_types_conversion_enabled
+     if not enabled:
+         return t
+     if isinstance(t, _IntegralType):
+         precision = getattr(t, "_precision", None)
+
+         if precision is None:
+             return t
+         elif precision == 19:
+             return LongType()
+         elif precision == 10:
+             return IntegerType()
+         elif precision == 5:
+             return ShortType()
+         elif precision == 3:
+             return ByteType()
+         else:
+             return DecimalType(precision, 0)
+
+     elif isinstance(t, StructType):
+         new_fields = [
+             StructField(
+                 field.name,
+                 emulate_integral_types(field.datatype),
+                 field.nullable,
+                 _is_column=field._is_column,
+             )
+             for field in t.fields
+         ]
+         return StructType(new_fields)
+
+     elif isinstance(t, ArrayType):
+         return ArrayType(
+             emulate_integral_types(t.element_type),
+             t.contains_null,
+         )
+
+     elif isinstance(t, MapType):
+         return MapType(
+             emulate_integral_types(t.key_type),
+             emulate_integral_types(t.value_type),
+             t.value_contains_null,
+         )
+
+     return t
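For reference, the mapping applied by emulate_integral_types boils down to a small precision table. A standalone sketch of that table follows; plain strings stand in for the Snowpark type objects, and the sample precisions are hypothetical inputs:

# Standalone sketch of the precision-to-type mapping above; the real code
# inspects Snowpark _IntegralType instances and returns Snowpark type objects.
def integral_type_for_precision(precision: int) -> str:
    mapping = {19: "LongType", 10: "IntegerType", 5: "ShortType", 3: "ByteType"}
    # Anything else keeps its exact precision as DecimalType(precision, 0).
    return mapping.get(precision, f"DecimalType({precision}, 0)")

assert integral_type_for_precision(10) == "IntegerType"
assert integral_type_for_precision(7) == "DecimalType(7, 0)"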
snowflake/snowpark_connect/utils/context.py
@@ -12,6 +12,9 @@ from typing import Iterator, Mapping, Optional
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+ from snowflake.snowpark_connect.type_support import (
+     set_integral_types_for_client_default,
+ )
  from snowflake.snowpark_connect.typed_column import TypedColumn
 
  # TODO: remove session id from context when we host SAS in Snowflake server
@@ -267,6 +270,11 @@ def set_spark_version(client_type: str) -> None:
      version = match.group("spark_version") if match else ""
      _spark_version.set(version)
 
+     # enable integral types (only if config is "client_default")
+
+     is_python_client = "_SPARK_CONNECT_PYTHON" in client_type
+     set_integral_types_for_client_default(is_python_client)
+
 
  def get_is_aggregate_function() -> tuple[str, bool]:
      """
snowflake/snowpark_connect/utils/java_stored_procedure.py
@@ -7,11 +7,22 @@ from pyspark.errors import AnalysisException
  import snowflake.snowpark.types as snowpark_type
  from snowflake.snowpark import Session
  from snowflake.snowpark._internal.type_utils import type_string_to_type_object
+ from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+ from snowflake.snowpark_connect.config import (
+     get_scala_version,
+     is_java_udf_creator_initialized,
+     set_java_udf_creator_initialized_state,
+ )
+ from snowflake.snowpark_connect.error.error_codes import ErrorCodes
  from snowflake.snowpark_connect.resources_initializer import (
      RESOURCE_PATH,
-     SPARK_COMMON_UTILS_JAR,
-     SPARK_CONNECT_CLIENT_JAR,
-     SPARK_SQL_JAR,
+     SPARK_COMMON_UTILS_JAR_212,
+     SPARK_COMMON_UTILS_JAR_213,
+     SPARK_CONNECT_CLIENT_JAR_212,
+     SPARK_CONNECT_CLIENT_JAR_213,
+     SPARK_SQL_JAR_212,
+     SPARK_SQL_JAR_213,
+     ensure_scala_udf_jars_uploaded,
  )
  from snowflake.snowpark_connect.utils.upload_java_jar import upload_java_udf_jar
 
@@ -22,7 +33,7 @@ CREATE OR REPLACE TEMPORARY PROCEDURE __SC_JAVA_SP_CREATE_JAVA_UDF(udf_name VARC
  RETURNS VARCHAR
  LANGUAGE JAVA
  RUNTIME_VERSION = 17
- PACKAGES = ('com.snowflake:snowpark:latest')
+ PACKAGES = ('com.snowflake:snowpark___scala_version__:latest')
  __snowflake_udf_imports__
  HANDLER = 'com.snowflake.snowpark_connect.procedures.JavaUDFCreator.process'
  EXECUTE AS CALLER
@@ -30,19 +41,6 @@ EXECUTE AS CALLER
  """
 
 
- _is_initialized = False
-
-
- def is_initialized() -> bool:
-     global _is_initialized
-     return _is_initialized
-
-
- def set_java_udf_creator_initialized_state(value: bool) -> None:
-     global _is_initialized
-     _is_initialized = value
-
-
  class JavaUdf:
      """
      Reference class for Java UDFs, providing similar properties like Python UserDefinedFunction.
@@ -70,12 +68,33 @@ class JavaUdf:
          self._return_type = return_type
 
 
+ def _scala_static_imports_for_sproc(stage_resource_path: str) -> set[str]:
+     scala_version = get_scala_version()
+     if scala_version == "2.12":
+         return {
+             f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_212}",
+             f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_212}",
+             f"{stage_resource_path}/{SPARK_SQL_JAR_212}",
+         }
+
+     if scala_version == "2.13":
+         return {
+             f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_213}",
+             f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_213}",
+             f"{stage_resource_path}/{SPARK_SQL_JAR_213}",
+         }
+
+     # invalid Scala version
+     exception = ValueError(
+         f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+     )
+     attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+     raise exception
+
+
  def get_quoted_imports(session: Session) -> str:
      stage_resource_path = session.get_session_stage() + RESOURCE_PATH
-     spark_imports = {
-         f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR}",
-         f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR}",
-         f"{stage_resource_path}/{SPARK_SQL_JAR}",
+     spark_imports = _scala_static_imports_for_sproc(stage_resource_path) | {
          f"{stage_resource_path}/java_udfs-1.0-SNAPSHOT.jar",
      }
 
@@ -83,14 +102,21 @@ def get_quoted_imports(session: Session) -> str:
          """Helper function to wrap strings in single quotes for SQL."""
          return "'" + s + "'"
 
-     return ", ".join(quote_single(x) for x in session._artifact_jars | spark_imports)
+     from snowflake.snowpark_connect.config import global_config
 
+     config_imports = global_config.get("snowpark.connect.udf.java.imports", "")
+     config_imports = (
+         {x.strip() for x in config_imports.strip("[] ").split(",") if x.strip()}
+         if config_imports
+         else set()
+     )
 
- def create_snowflake_imports(session: Session) -> str:
-     from snowflake.snowpark_connect.resources_initializer import (
-         ensure_scala_udf_jars_uploaded,
+     return ", ".join(
+         quote_single(x) for x in session._artifact_jars | spark_imports | config_imports
      )
 
+
+ def create_snowflake_imports(session: Session) -> str:
      # Make sure that the resource initializer thread is completed before creating Java UDFs since we depend on the jars
      # uploaded by it.
      ensure_scala_udf_jars_uploaded()
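The new snowpark.connect.udf.java.imports setting accepts a bracketed, comma-separated list of extra jar paths that get merged into the IMPORTS clause. A standalone sketch of the parsing added above; the stage paths are hypothetical:

# Sketch of how the config value is split into individual jar paths.
def parse_extra_imports(raw: str) -> set[str]:
    if not raw:
        return set()
    return {x.strip() for x in raw.strip("[] ").split(",") if x.strip()}

assert parse_extra_imports("[@mystage/a.jar, @mystage/b.jar]") == {"@mystage/a.jar", "@mystage/b.jar"}
assert parse_extra_imports("") == set()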
@@ -99,12 +125,12 @@ def create_snowflake_imports(session: Session) -> str:
 
 
  def create_java_udf(session: Session, function_name: str, java_class: str):
-     if not is_initialized():
+     if not is_java_udf_creator_initialized():
          upload_java_udf_jar(session)
          session.sql(
              SP_TEMPLATE.replace(
                  "__snowflake_udf_imports__", create_snowflake_imports(session)
-             )
+             ).replace("__scala_version__", get_scala_version())
          ).collect()
          set_java_udf_creator_initialized_state(True)
      name = CREATE_JAVA_UDF_PREFIX + function_name
snowflake/snowpark_connect/utils/java_udaf_utils.py
@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.utils.jvm_udf_utils import (
      ReturnType,
      Signature,
      build_jvm_udxf_imports,
-     cast_java_map_args_from_given_type,
      map_type_to_java_type,
  )
  from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
@@ -41,19 +40,20 @@ import com.snowflake.snowpark_java.types.*;
 
  public class JavaUDAF {
      private final static String OPERATION_FILE = "__operation_file__";
-     private static scala.Function2<__accumulator_type__, __value_type__, __value_type__> operation = null;
+     private static scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__> operation = null;
+     private static UdfPacket udfPacket = null;
 
      private static void loadOperation() throws IOException, ClassNotFoundException {
          if (operation != null) {
              return; // Already loaded
          }
 
-         final UdfPacket udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
-         operation = (scala.Function2<__accumulator_type__, __value_type__, __value_type__>) udfPacket.function();
+         udfPacket = com.snowflake.sas.scala.Utils$.MODULE$.deserializeUdfPacket(OPERATION_FILE);
+         operation = (scala.Function2<__reduce_type__, __reduce_type__, __reduce_type__>) udfPacket.function();
      }
 
      public static class State implements Serializable {
-         public __accumulator_type__ value = null;
+         public __reduce_type__ value = null;
          public boolean initialized = false;
      }
 
@@ -69,10 +69,10 @@ public class JavaUDAF {
          }
 
          if (!state.initialized) {
-             state.value = input;
+             state.value = __mapped_value__;
              state.initialized = true;
          } else {
-             state.value = operation.apply(state.value, input);
+             state.value = operation.apply(state.value, __mapped_value__);
          }
          return state;
      }
@@ -115,7 +115,6 @@ class JavaUDAFDef:
      name: str
      signature: Signature
      java_signature: Signature
-     java_invocation_args: list[str]
      imports: list[str]
      null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT
 
@@ -131,17 +130,31 @@ class JavaUDAFDef:
          Returns:
              String containing the complete Java code for the UDAF body
          """
-         returns_variant = self.signature.returns.data_type == "VARIANT"
+         returns_variant = self.signature.returns.data_type.lower() == "variant"
          return_type = (
              "Variant" if returns_variant else self.java_signature.params[0].data_type
          )
          response_wrapper = (
-             "new Variant(state.value)" if returns_variant else "state.value"
+             "com.snowflake.sas.scala.Utils$.MODULE$.toVariant(state.value, udfPacket)"
+             if returns_variant
+             else "state.value"
+         )
+
+         is_variant_input = self.java_signature.params[0].data_type.lower() == "variant"
+         reduce_type = (
+             "Object" if is_variant_input else self.java_signature.params[0].data_type
          )
          return (
              UDAF_TEMPLATE.replace("__operation_file__", self.imports[0].split("/")[-1])
              .replace("__accumulator_type__", self.java_signature.params[0].data_type)
              .replace("__value_type__", self.java_signature.params[1].data_type)
+             .replace(
+                 "__mapped_value__",
+                 "com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0)"
+                 if is_variant_input
+                 else "input",
+             )
+             .replace("__reduce_type__", reduce_type)
              .replace("__return_type__", return_type)
              .replace("__response_wrapper__", response_wrapper)
          )
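The UDAF body is produced by plain string substitution on the Java template: Variant-typed inputs are accumulated as Object and unwrapped through the deserialized UdfPacket. A toy sketch of that substitution; the template line is abridged and only the placeholder names come from the diff:

# Toy template carrying the same placeholders as UDAF_TEMPLATE (heavily abridged).
toy_template = "state.value = operation.apply((__reduce_type__) state.value, __mapped_value__);"

is_variant_input = True  # e.g. the reduce column is ARRAY/MAP/VARIANT
body = toy_template.replace(
    "__reduce_type__", "Object" if is_variant_input else "String"
).replace(
    "__mapped_value__",
    "com.snowflake.sas.scala.UdfPacketUtils$.MODULE$.fromVariant(udfPacket, input, 0)"
    if is_variant_input
    else "input",
)
print(body)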
@@ -231,12 +244,11 @@ def create_java_udaf_for_reduce_scala_function(
          A JavaUdaf object representing the Java UDAF.
      """
      from snowflake.snowpark_connect.resources_initializer import (
-         wait_for_resource_initialization,
+         ensure_scala_udf_jars_uploaded,
      )
 
-     # Make sure that the resource initializer thread is completed before creating Java UDFs since we depend on the jars
-     # uploaded by it.
-     wait_for_resource_initialization()
+     # Make sure Scala UDF jars are uploaded before creating Java UDAFs since we depend on them.
+     ensure_scala_udf_jars_uploaded()
 
      from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 
@@ -252,23 +264,26 @@ def create_java_udaf_for_reduce_scala_function(
 
      java_input_params: list[Param] = []
      sql_input_params: list[Param] = []
-     java_invocation_args: list[str] = []  # arguments passed into the udf function
      if input_types:  # input_types can be None when no arguments are provided
          for i, input_type in enumerate(input_types):
              param_name = "arg" + str(i)
+             if isinstance(
+                 input_type,
+                 (
+                     snowpark_type.ArrayType,
+                     snowpark_type.MapType,
+                     snowpark_type.VariantType,
+                 ),
+             ):
+                 java_type = "Variant"
+                 snowflake_type = "Variant"
+             else:
+                 java_type = map_type_to_java_type(input_type)
+                 snowflake_type = map_type_to_snowflake_type(input_type)
              # Create the Java arguments and input types string: "arg0: Type0, arg1: Type1, ...".
-             java_input_params.append(
-                 Param(param_name, map_type_to_java_type(input_type))
-             )
+             java_input_params.append(Param(param_name, java_type))
              # Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
-             sql_input_params.append(
-                 Param(param_name, map_type_to_snowflake_type(input_type))
-             )
-             # In the case of Map input types, we need to cast the argument to the correct type in Java.
-             # Snowflake SQL Java can only handle MAP[VARCHAR, VARCHAR] as input types.
-             java_invocation_args.append(
-                 cast_java_map_args_from_given_type(param_name, input_type)
-             )
+             sql_input_params.append(Param(param_name, snowflake_type))
 
      java_return_type = map_type_to_java_type(pciudf._original_return_type)
      # If the SQL return type is a MAP or STRUCT, change this to VARIANT because of issues with Java UDAFs.
@@ -282,7 +297,11 @@ def create_java_udaf_for_reduce_scala_function(
      )
      sql_return_type = (
          "VARIANT"
-         if (sql_return_type.startswith("MAP") or sql_return_type.startswith("OBJECT"))
+         if (
+             sql_return_type.startswith("MAP")
+             or sql_return_type.startswith("OBJECT")
+             or sql_return_type.startswith("ARRAY")
+         )
          else sql_return_type
      )
 
@@ -295,7 +314,6 @@ def create_java_udaf_for_reduce_scala_function(
          java_signature=Signature(
              params=java_input_params, returns=ReturnType(java_return_type)
          ),
-         java_invocation_args=java_invocation_args,
      )
      create_udf_sql = udf_def.to_create_function_sql()
      logger.info(f"Creating Java UDAF: {create_udf_sql}")
snowflake/snowpark_connect/utils/java_udtf_utils.py
@@ -95,7 +95,7 @@ public class JavaUdtfHandler {
      java.util.Iterator<Variant> javaResult = new java.util.Iterator<Variant>() {
          public boolean hasNext() { return scalaResult.hasNext(); }
          public Variant next() {
-             return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next());
+             return com.snowflake.sas.scala.Utils$.MODULE$.toVariant(scalaResult.next(), udfPacket);
          }
      };
 
snowflake/snowpark_connect/utils/jvm_udf_utils.py
@@ -9,16 +9,23 @@ from typing import List, Union
  import snowflake.snowpark.types as snowpark_type
  import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
  from snowflake import snowpark
+ from snowflake.snowpark_connect.config import get_scala_version
  from snowflake.snowpark_connect.error.error_codes import ErrorCodes
  from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
  from snowflake.snowpark_connect.resources_initializer import (
-     JSON_4S_JAR,
+     JSON_4S_JAR_212,
+     JSON_4S_JAR_213,
      RESOURCE_PATH,
-     SAS_SCALA_UDF_JAR,
-     SCALA_REFLECT_JAR,
-     SPARK_COMMON_UTILS_JAR,
-     SPARK_CONNECT_CLIENT_JAR,
-     SPARK_SQL_JAR,
+     SAS_SCALA_UDF_JAR_212,
+     SAS_SCALA_UDF_JAR_213,
+     SCALA_REFLECT_JAR_212,
+     SCALA_REFLECT_JAR_213,
+     SPARK_COMMON_UTILS_JAR_212,
+     SPARK_COMMON_UTILS_JAR_213,
+     SPARK_CONNECT_CLIENT_JAR_212,
+     SPARK_CONNECT_CLIENT_JAR_213,
+     SPARK_SQL_JAR_212,
+     SPARK_SQL_JAR_213,
  )
 
 
@@ -108,15 +115,41 @@ def build_jvm_udxf_imports(
      )
 
      # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-     return [
-         closure_binary_file,
-         f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR}",
-         f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR}",
-         f"{stage_resource_path}/{SPARK_SQL_JAR}",
-         f"{stage_resource_path}/{JSON_4S_JAR}",
-         f"{stage_resource_path}/{SAS_SCALA_UDF_JAR}",
-         f"{stage_resource_path}/{SCALA_REFLECT_JAR}",  # Required for deserializing Scala lambdas
-     ] + list(session._artifact_jars)
+     return (
+         [closure_binary_file]
+         + _scala_static_imports_for_udf(stage_resource_path)
+         + list(session._artifact_jars)
+     )
+
+
+ def _scala_static_imports_for_udf(stage_resource_path: str) -> list[str]:
+     scala_version = get_scala_version()
+     if scala_version == "2.12":
+         return [
+             f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_212}",
+             f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_212}",
+             f"{stage_resource_path}/{SPARK_SQL_JAR_212}",
+             f"{stage_resource_path}/{JSON_4S_JAR_212}",
+             f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_212}",
+             f"{stage_resource_path}/{SCALA_REFLECT_JAR_212}",  # Required for deserializing Scala lambdas
+         ]
+
+     if scala_version == "2.13":
+         return [
+             f"{stage_resource_path}/{SPARK_CONNECT_CLIENT_JAR_213}",
+             f"{stage_resource_path}/{SPARK_COMMON_UTILS_JAR_213}",
+             f"{stage_resource_path}/{SPARK_SQL_JAR_213}",
+             f"{stage_resource_path}/{JSON_4S_JAR_213}",
+             f"{stage_resource_path}/{SAS_SCALA_UDF_JAR_213}",
+             f"{stage_resource_path}/{SCALA_REFLECT_JAR_213}",  # Required for deserializing Scala lambdas
+         ]
+
+     # invalid Scala version
+     exception = ValueError(
+         f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+     )
+     attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+     raise exception
 
 
  def map_type_to_java_type(
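Both the stored-procedure and UDF import helpers now select jars by the configured Scala binary version instead of a single hard-coded set. A standalone sketch of that selection; jar names are taken from the includes list above but abridged, and the stage path is hypothetical:

# Sketch: choose the jar set matching the configured Scala version, else fail fast.
def select_scala_jars(stage: str, scala_version: str) -> list[str]:
    jars_by_version = {
        "2.12": ["sas-scala-udf_2.12-0.2.0.jar"],
        "2.13": ["sas-scala-udf_2.13-0.2.0.jar", "spark-sql_2.13-3.5.6.jar"],
    }
    if scala_version not in jars_by_version:
        raise ValueError(f"Unsupported Scala version: {scala_version}")
    return [f"{stage}/{jar}" for jar in jars_by_version[scala_version]]

print(select_scala_jars("@session_stage/snowpark-connect-resources", "2.13"))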