snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. snowflake/snowpark_connect/client/server.py +37 -0
  2. snowflake/snowpark_connect/config.py +72 -3
  3. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  4. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  5. snowflake/snowpark_connect/expression/map_cast.py +108 -17
  6. snowflake/snowpark_connect/expression/map_udf.py +1 -0
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
  8. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  11. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  12. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  13. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
  14. snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
  15. snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
  16. snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
  17. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
  18. snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
  20. snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
  21. snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
  22. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
  23. snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
  24. snowflake/snowpark_connect/resources_initializer.py +90 -29
  25. snowflake/snowpark_connect/server.py +6 -41
  26. snowflake/snowpark_connect/server_common/__init__.py +4 -1
  27. snowflake/snowpark_connect/type_support.py +130 -0
  28. snowflake/snowpark_connect/utils/context.py +8 -0
  29. snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
  30. snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
  31. snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
  32. snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
  33. snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
  34. snowflake/snowpark_connect/utils/telemetry.py +33 -22
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  36. snowflake/snowpark_connect/version.py +1 -1
  37. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
  38. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
  39. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
  40. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_aggregate.py
@@ -11,9 +11,10 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
-from snowflake.snowpark.types import DataType
+from snowflake.snowpark.types import DataType, StructType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
+    make_unique_snowpark_name,
 )
 from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.dataframe_container import (
@@ -57,6 +58,47 @@ def map_group_by_aggregate(
         *columns.aggregation_expressions()
     )
 
+    for rel_aggregate_expression, aggregate_original_column in zip(
+        rel.aggregate.aggregate_expressions, columns.aggregation_columns
+    ):
+        aggregate_original_data_type = aggregate_original_column.data_type
+
+        if not (
+            rel_aggregate_expression.HasField("unresolved_function")
+            and rel_aggregate_expression.unresolved_function.function_name == "reduce"
+        ) or not isinstance(aggregate_original_data_type, StructType):
+            continue
+
+        cols = []
+        new_snowpark_column_names = []
+        new_snowpark_column_types = [
+            field.datatype for field in aggregate_original_data_type.fields
+        ]
+
+        if not result.columns or len(result.columns) != 1:
+            raise ValueError(
+                "Expected result DataFrame to have exactly one column for reduce(StructType)"
+            )
+        aggregate_col = snowpark_fn.col(result.columns[0])
+
+        # Extract each field from the StructType result after aggregation to create separate columns
+        for spark_col_name in input_df_container.column_map.get_spark_columns():
+            unique_snowpark_name = make_unique_snowpark_name(spark_col_name)
+            cols.append(
+                snowpark_fn.get(aggregate_col, snowpark_fn.lit(spark_col_name)).alias(
+                    unique_snowpark_name
+                )
+            )
+            new_snowpark_column_names.append(unique_snowpark_name)
+
+        result = result.select(*cols)
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=input_df_container.column_map.get_spark_columns(),
+            snowpark_column_names=new_snowpark_column_names,
+            snowpark_column_types=new_snowpark_column_types,
+        )
+
     # Store aggregate metadata for ORDER BY resolution
     aggregate_metadata = AggregateMetadata(
         input_column_map=input_df_container.column_map,
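Note: the new branch above unwraps a reduce() aggregate whose result arrives as a single struct-typed column, producing one output column per struct field. A minimal standalone sketch of that flattening pattern using only public Snowpark functions (the DataFrame, column, and field names here are illustrative, not taken from the package):

import snowflake.snowpark.functions as F

def flatten_struct_column(df, struct_col: str, field_names: list[str]):
    # df is a snowflake.snowpark.DataFrame whose column `struct_col` holds an OBJECT/struct value.
    # Project one column per field, aliased to the field name.
    projected = [
        F.get(F.col(struct_col), F.lit(name)).alias(name) for name in field_names
    ]
    return df.select(*projected)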
snowflake/snowpark_connect/relation/read/map_read_csv.py
@@ -11,7 +11,18 @@ from pyspark.errors.exceptions.base import AnalysisException
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark.dataframe_reader import DataFrameReader
-from snowflake.snowpark.types import StringType, StructField, StructType
+from snowflake.snowpark.types import (
+    DataType,
+    DecimalType,
+    DoubleType,
+    IntegerType,
+    LongType,
+    StringType,
+    StructField,
+    StructType,
+    _FractionalType,
+    _IntegralType,
+)
 from snowflake.snowpark_connect.config import global_config, str_to_bool
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -26,6 +37,10 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import (
+    _integral_types_conversion_enabled,
+    emulate_integral_types,
+)
 from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -108,7 +123,7 @@ def map_read_csv(
             df = df.union_all(reader.csv(p))
 
     if schema is None and not str_to_bool(
-        str(raw_options.get("inferSchema", "false"))
+        str(raw_options.get("inferSchema", raw_options.get("inferschema", "false")))
     ):
         df = df.select(
             [snowpark_fn.col(c).cast("STRING").alias(c) for c in df.schema.names]
@@ -123,7 +138,9 @@
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[f.datatype for f in df.schema.fields],
+        snowpark_column_types=[
+            _emulate_integral_types_for_csv(f.datatype) for f in df.schema.fields
+        ],
     )
 
 
@@ -320,7 +337,13 @@ def read_data(
     # Create schema with the column names and read CSV
     if len(headers) > 0:
         if (
-            not str_to_bool(str(raw_options.get("inferSchema", "false")))
+            not str_to_bool(
+                str(
+                    raw_options.get(
+                        "inferSchema", raw_options.get("inferschema", "false")
+                    )
+                )
+            )
             and schema is None
         ):
             inferred_schema = StructType(
@@ -350,3 +373,49 @@
 
     # Fallback: no headers, shouldn't reach here
     return reader.csv(path)
+
+
+def _emulate_integral_types_for_csv(t: DataType) -> DataType:
+    """
+    CSV requires different type handling to match OSS Spark CSV schema inference.
+
+    After applying emulate_integral_types, converts to Spark CSV types:
+    - IntegerType, ShortType, ByteType -> IntegerType
+    - LongType -> LongType
+    - DecimalType with scale > 0 -> DoubleType
+    - DecimalType with precision > 18 -> DecimalType (too big for long)
+    - DecimalType with precision > 9 -> LongType
+    - DecimalType with precision <= 9 -> IntegerType
+    - FloatType, DoubleType -> DoubleType
+    """
+    if not _integral_types_conversion_enabled:
+        return t
+
+    # First apply standard integral type conversion
+    t = emulate_integral_types(t)
+
+    if isinstance(t, LongType):
+        return LongType()
+
+    elif isinstance(t, _IntegralType):
+        # ByteType, ShortType, IntegerType -> IntegerType
+        return IntegerType()
+
+    elif isinstance(t, DecimalType):
+        # DecimalType with scale > 0 means it has decimal places -> DoubleType
+        if t.scale > 0:
+            return DoubleType()
+        # DecimalType with scale = 0 is integral
+        if t.precision > 18:
+            # Too big for long, keep as DecimalType
+            return DecimalType(t.precision, 0)
+        elif t.precision > 9:
+            return LongType()
+        else:
+            return IntegerType()
+
+    elif isinstance(t, _FractionalType):
+        # FloatType, DoubleType -> DoubleType
+        return DoubleType()
+
+    return t
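Note: the inferSchema changes above fall back to the all-lowercase "inferschema" key because Spark reader options are case-insensitive while a plain dict lookup is not. A hypothetical helper (not part of the package) expressing the same idea for any option name:

def get_option_case_insensitive(raw_options: dict, name: str, default: str) -> str:
    # Lower-case all keys once, then look the option up case-insensitively.
    lowered = {str(k).lower(): v for k, v in raw_options.items()}
    return str(lowered.get(name.lower(), default))

# Both spellings resolve to the same value:
assert get_option_case_insensitive({"inferschema": "true"}, "inferSchema", "false") == "true"
assert get_option_case_insensitive({"inferSchema": "true"}, "inferSchema", "false") == "true"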
snowflake/snowpark_connect/relation/read/map_read_jdbc.py
@@ -16,6 +16,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     Connection,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 
 
@@ -112,7 +113,9 @@ def map_read_jdbc(
             dataframe=renamed_df,
             spark_column_names=true_names,
             snowpark_column_names=snowpark_cols,
-            snowpark_column_types=[f.datatype for f in df.schema.fields],
+            snowpark_column_types=[
+                emulate_integral_types(f.datatype) for f in df.schema.fields
+            ],
         )
     except Exception as e:
         exception = Exception(f"Error accessing JDBC datasource for read: {e}")
snowflake/snowpark_connect/relation/read/map_read_json.py
@@ -45,6 +45,7 @@ from snowflake.snowpark_connect.type_mapping import (
     map_simple_types,
     merge_different_types,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -148,7 +149,9 @@ def map_read_json(
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[f.datatype for f in df.schema.fields],
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
     )
 
 
snowflake/snowpark_connect/relation/read/map_read_parquet.py
@@ -44,6 +44,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     apply_metadata_exclusion_pattern,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -126,7 +127,9 @@ def map_read_parquet(
         dataframe=renamed_df,
         spark_column_names=[analyzer_utils.unquote_if_quoted(c) for c in df.columns],
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[f.datatype for f in df.schema.fields],
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
         can_be_cached=can_be_cached,
     )
 
snowflake/snowpark_connect/relation/read/map_read_socket.py
@@ -11,6 +11,7 @@ from snowflake import snowpark
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -58,6 +59,9 @@ def map_read_socket(
             dataframe=df,
             spark_column_names=[spark_cname],
             snowpark_column_names=[snowpark_cname],
+            snowpark_column_types=[
+                emulate_integral_types(f.datatype) for f in df.schema.fields
+            ],
         )
     except OSError as e:
         exception = Exception(f"Error connecting to {host}:{port} - {e}")
snowflake/snowpark_connect/relation/read/map_read_table.py
@@ -24,6 +24,7 @@ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_cod
 from snowflake.snowpark_connect.relation.read.utils import (
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.context import get_processed_views
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
@@ -58,7 +59,9 @@ def post_process_df(
         dataframe=renamed_df,
         spark_column_names=true_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[f.datatype for f in df.schema.fields],
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
         column_qualifiers=[{ColumnQualifier(tuple(name_parts))} for _ in true_names]
         if source_table_name
         else None,
snowflake/snowpark_connect/relation/read/map_read_text.py
@@ -14,6 +14,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.io_utils import file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -117,5 +118,7 @@ def map_read_text(
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[f.datatype for f in df.schema.fields],
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
     )
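Note: the same three-line change, wrapping each field's datatype in emulate_integral_types, is applied to every reader above except CSV, which uses its own wrapper: JDBC, JSON, Parquet, socket, table, and text. A hypothetical shared helper expressing that repeated pattern (the name is illustrative, not from the package):

from snowflake.snowpark.types import DataType

def emulated_column_types(df, emulate) -> list[DataType]:
    # df: snowflake.snowpark.DataFrame; emulate: e.g. emulate_integral_types
    return [emulate(field.datatype) for field in df.schema.fields]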
snowflake/snowpark_connect/relation/read/reader_config.py
@@ -441,4 +441,10 @@ class ParquetReaderConfig(ReaderWriterConfig):
             "snowpark.connect.parquet.useVectorizedScanner"
         )
 
+        # Set USE_LOGICAL_TYPE from global config to properly handle Parquet logical types like TIMESTAMP.
+        # Without this, Parquet TIMESTAMP (INT64 physical) is incorrectly read as NUMBER(38,0).
+        snowpark_args["USE_LOGICAL_TYPE"] = global_config._get_config_setting(
+            "snowpark.connect.parquet.useLogicalType"
+        )
+
         return snowpark_args
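Note: USE_LOGICAL_TYPE is the Snowflake Parquet file-format option that interprets Parquet logical types, so a TIMESTAMP stored as physical INT64 is read back as a timestamp rather than NUMBER(38,0). Assuming snowpark.connect.parquet.useLogicalType is set like the other snowpark.connect.* options, a client could opt in through the Spark conf; a hedged usage sketch with an illustrative stage path:

# `spark` is an existing Spark Connect session backed by Snowpark Connect.
spark.conf.set("snowpark.connect.parquet.useLogicalType", "true")
df = spark.read.parquet("@my_stage/events/")  # hypothetical path
df.printSchema()  # expect timestamp columns instead of NUMBER(38,0)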
snowflake/snowpark_connect/resources_initializer.py
@@ -3,7 +3,12 @@
 #
 import threading
 import time
+from collections.abc import Callable
+from pathlib import Path
 
+from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.config import get_scala_version
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 
@@ -11,22 +16,31 @@ SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"
 
 # On demand Scala UDF jar upload state - separate from general resource initialization
-_scala_jars_uploaded = threading.Event()
-_scala_jars_lock = threading.Lock()
+_scala_2_12_jars_uploaded = threading.Event()
+_scala_2_12_jars_lock = threading.Lock()
+_scala_2_13_jars_uploaded = threading.Event()
+_scala_2_13_jars_lock = threading.Lock()
 
 # Define Scala resource names
-SPARK_SQL_JAR = f"spark-sql_2.12-{SPARK_VERSION}.jar"
-SPARK_CONNECT_CLIENT_JAR = f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar"
-SPARK_COMMON_UTILS_JAR = f"spark-common-utils_2.12-{SPARK_VERSION}.jar"
-SAS_SCALA_UDF_JAR = "sas-scala-udf_2.12-0.2.0.jar"
-JSON_4S_JAR = "json4s-ast_2.12-3.7.0-M11.jar"
-SCALA_REFLECT_JAR = "scala-reflect-2.12.18.jar"
-
-
-def _upload_scala_udf_jars_impl() -> None:
+SPARK_SQL_JAR_212 = f"spark-sql_2.12-{SPARK_VERSION}.jar"
+SPARK_CONNECT_CLIENT_JAR_212 = f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar"
+SPARK_COMMON_UTILS_JAR_212 = f"spark-common-utils_2.12-{SPARK_VERSION}.jar"
+SAS_SCALA_UDF_JAR_212 = "sas-scala-udf_2.12-0.2.0.jar"
+JSON_4S_JAR_212 = "json4s-ast_2.12-3.7.0-M11.jar"
+SCALA_REFLECT_JAR_212 = "scala-reflect-2.12.18.jar"
+
+# Static dependencies for Scala 2.13
+SPARK_SQL_JAR_213 = f"spark-sql_2.13-{SPARK_VERSION}.jar"
+SPARK_CONNECT_CLIENT_JAR_213 = f"spark-connect-client-jvm_2.13-{SPARK_VERSION}.jar"
+SPARK_COMMON_UTILS_JAR_213 = f"spark-common-utils_2.13-{SPARK_VERSION}.jar"
+SAS_SCALA_UDF_JAR_213 = "sas-scala-udf_2.13-0.2.0.jar"
+JSON_4S_JAR_213 = "json4s-ast_2.13-3.7.0-M11.jar"
+SCALA_REFLECT_JAR_213 = "scala-reflect-2.13.16.jar"
+
+
+def _upload_scala_udf_jars(jar_files: list[str]) -> None:
     """Upload Spark jar files required for creating Scala UDFs.
     This is the internal implementation - use ensure_scala_udf_jars_uploaded() for thread-safe lazy loading."""
-    from pathlib import Path
 
     session = get_or_create_snowpark_session()
     stage = session.get_session_stage()
@@ -34,15 +48,6 @@ def _upload_scala_udf_jars_impl() -> None:
     import snowpark_connect_deps_1
     import snowpark_connect_deps_2
 
-    jar_files = [
-        SPARK_SQL_JAR,
-        SPARK_CONNECT_CLIENT_JAR,
-        SPARK_COMMON_UTILS_JAR,
-        SAS_SCALA_UDF_JAR,
-        JSON_4S_JAR,
-        SCALA_REFLECT_JAR,  # Required for deserializing Scala lambdas
-    ]
-
     # Path to includes/jars directory
     includes_jars_dir = Path(__file__).parent / "includes" / "jars"
 
@@ -78,31 +83,87 @@ def _upload_scala_udf_jars_impl() -> None:
             raise RuntimeError(f"Failed to upload JAR {jar_name}: {e}")
 
 
-def ensure_scala_udf_jars_uploaded() -> None:
-    """Ensure Scala UDF jars are uploaded to Snowflake, uploading them lazily if not already done.
-    This function is thread-safe and will only upload once even if called from multiple threads."""
+def _upload_scala_2_12_jars() -> None:
+    scala_2_12_jars = [
+        SPARK_SQL_JAR_212,
+        SPARK_CONNECT_CLIENT_JAR_212,
+        SPARK_COMMON_UTILS_JAR_212,
+        SAS_SCALA_UDF_JAR_212,
+        JSON_4S_JAR_212,
+        SCALA_REFLECT_JAR_212,  # Required for deserializing Scala lambdas
+    ]
+    _upload_scala_udf_jars(scala_2_12_jars)
 
+
+def _upload_scala_2_13_jars() -> None:
+    scala_2_13_jars = [
+        SPARK_SQL_JAR_213,
+        SPARK_CONNECT_CLIENT_JAR_213,
+        SPARK_COMMON_UTILS_JAR_213,
+        SAS_SCALA_UDF_JAR_213,
+        JSON_4S_JAR_213,
+        SCALA_REFLECT_JAR_213,
+    ]
+    _upload_scala_udf_jars(scala_2_13_jars)
+
+
+def _ensure_configured_scala_jars_uploaded(
+    jars_uploaded: threading.Event, lock: threading.Lock, upload_fn: Callable[[], None]
+) -> None:
+    """
+    Ensure Scala UDF jars are uploaded to Snowflake, uploading them lazily if not already done.
+    This function is thread-safe and will only upload once even if called from multiple threads.
+
+    Uses the given upload_fn to upload Scala jars if the jars_uploaded event is not set yet.
+    """
     # Fast path: if already uploaded, return immediately without acquiring lock
-    if _scala_jars_uploaded.is_set():
+    if jars_uploaded.is_set():
        return
 
     # Slow path: need to upload, acquire lock to ensure only one thread does it
-    with _scala_jars_lock:
+    with lock:
         # Double-check pattern: another thread might have uploaded while we waited for the lock
-        if _scala_jars_uploaded.is_set():
+        if jars_uploaded.is_set():
             return
 
         try:
             start_time = time.time()
             logger.info("Uploading Scala UDF jars on-demand...")
-            _upload_scala_udf_jars_impl()
-            _scala_jars_uploaded.set()
+            upload_fn()
+            jars_uploaded.set()
             logger.info(f"Scala UDF jars uploaded in {time.time() - start_time:.2f}s")
         except Exception as e:
             logger.error(f"Failed to upload Scala UDF jars: {e}")
             raise
 
 
+def ensure_scala_udf_jars_uploaded() -> None:
+    """
+    Public function to make sure Scala jars are uploaded and available for imports.
+    """
+    scala_version = get_scala_version()
+
+    match scala_version:
+        case "2.12":
+            _ensure_configured_scala_jars_uploaded(
+                _scala_2_12_jars_uploaded,
+                _scala_2_12_jars_lock,
+                _upload_scala_2_12_jars,
+            )
+        case "2.13":
+            _ensure_configured_scala_jars_uploaded(
+                _scala_2_13_jars_uploaded,
+                _scala_2_13_jars_lock,
+                _upload_scala_2_13_jars,
+            )
+        case _:
+            exception = ValueError(
+                f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+            raise exception
+
+
 def initialize_resources() -> None:
     """Initialize all expensive resources. We should initialize what we can here, so that actual rpc calls like
     ExecutePlan are as fast as possible."""
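Note: _ensure_configured_scala_jars_uploaded is the classic double-checked, run-once initialization built from a threading.Event and a Lock, now parameterized so each Scala version keeps its own upload state. The bare pattern, reduced to its essentials (illustrative, not the packaged code):

import threading
from collections.abc import Callable

def run_once(done: threading.Event, lock: threading.Lock, fn: Callable[[], None]) -> None:
    if done.is_set():  # fast path: skip the lock once initialization has happened
        return
    with lock:  # slow path: only one thread performs the work
        if done.is_set():  # double-check after acquiring the lock
            return
        fn()
        done.set()  # only set on success, so a failed attempt can be retried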
snowflake/snowpark_connect/server.py
@@ -45,7 +45,10 @@ import snowflake.snowpark_connect.proto.control_pb2_grpc as control_grpc
 import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
 from snowflake.snowpark_connect.analyze_plan.map_tree_string import map_tree_string
-from snowflake.snowpark_connect.config import route_config_proto
+from snowflake.snowpark_connect.config import (
+    route_config_proto,
+    set_java_udf_creator_initialized_state,
+)
 from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
 from snowflake.snowpark_connect.control_server import ControlServicer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -112,9 +115,6 @@ from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_queries_with_tag,
     interrupt_query,
 )
-from snowflake.snowpark_connect.utils.java_stored_procedure import (
-    set_java_udf_creator_initialized_state,
-)
 from snowflake.snowpark_connect.utils.open_telemetry import (
     is_telemetry_enabled,
     otel_attach_context,
@@ -1126,51 +1126,16 @@ def start_jvm():
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    import pathlib
-    import zipfile
-
-    import snowflake.snowpark_connect
-
     # Import both JAR dependency packages
     import snowpark_connect_deps_1
     import snowpark_connect_deps_2
 
-    # First, add JARs from includes/jars directory
-    pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes" / "jars"
-    )
-
-    if "dataframe_processor.zip" in str(pyspark_jars):
-        # importlib.resource doesn't work when local stage package is used in TCM
-        zip_path = pathlib.Path(
-            snowflake.snowpark_connect.__file__
-        ).parent.parent.parent
-        temp_dir = tempfile.gettempdir()
-        extract_folder = "snowflake/snowpark_connect/includes/jars/"  # Folder to extract (must end with '/')
-
-        with zipfile.ZipFile(zip_path, "r") as zip_ref:
-            for member in zip_ref.namelist():
-                if member.startswith(extract_folder):
-                    zip_ref.extract(member, path=temp_dir)
-        pyspark_jars = pathlib.Path(temp_dir) / extract_folder
-
-    included_jar_names = set()
-
-    if pyspark_jars.exists():
-        for jar_path in pyspark_jars.glob(
-            "**/*.jar"
-        ):  # Use **/*.jar to handle nested paths in TCM
-            jpype.addClassPath(str(jar_path))
-            included_jar_names.add(jar_path.name)
-
-    # Load jar files from both packages, skipping those already loaded from includes/jars
+    # Load all the jar files from both packages
     jar_path_list = (
         snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars()
     )
     for jar_path in jar_path_list:
-        # Skip if this JAR was already loaded from includes/jars
-        if jar_path.name not in included_jar_names:
-            jpype.addClassPath(jar_path)
+        jpype.addClassPath(jar_path)
 
     # TODO: Should remove convertStrings, but it breaks the JDBC code.
     jvm_settings: list[str] = list(
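Note: start_jvm() now registers every jar from the two dependency packages instead of first preferring copies bundled under includes/jars. The essential jpype pattern, sketched with a hypothetical iterable of packages that expose list_jars():

import jpype

def add_dependency_jars(dep_packages) -> None:
    # Classpath entries must be registered before jpype.startJVM() is called.
    for pkg in dep_packages:
        for jar_path in pkg.list_jars():
            jpype.addClassPath(str(jar_path))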
snowflake/snowpark_connect/server_common/__init__.py
@@ -369,7 +369,10 @@ def _setup_spark_environment(setup_java_home: bool = True) -> None:
     lightweight client servers that don't need JVM.
     """
     if setup_java_home:
-        if os.environ.get("JAVA_HOME") is None:
+        if (
+            os.environ.get("JAVA_HOME") is None
+            or str(os.environ.get("JAVA_HOME")).strip() == ""
+        ):
             try:
                 # For Notebooks on SPCS
                 from jdk4py import JAVA_HOME
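Note: the JAVA_HOME change above treats a variable set to an empty or whitespace-only string the same as an unset one, so the jdk4py fallback still applies. The check in isolation (illustrative helper, not from the package):

import os

def java_home_is_missing() -> bool:
    value = os.environ.get("JAVA_HOME")
    return value is None or value.strip() == ""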