snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/client/server.py +37 -0
- snowflake/snowpark_connect/config.py +72 -3
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/map_cast.py +108 -17
- snowflake/snowpark_connect/expression/map_udf.py +1 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
- snowflake/snowpark_connect/resources_initializer.py +90 -29
- snowflake/snowpark_connect/server.py +6 -41
- snowflake/snowpark_connect/server_common/__init__.py +4 -1
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/utils/context.py +8 -0
- snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
- snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
- snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
- snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
- snowflake/snowpark_connect/utils/telemetry.py +33 -22
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/snowpark_connect/relation/map_aggregate.py
+++ b/snowflake/snowpark_connect/relation/map_aggregate.py
@@ -11,9 +11,10 @@ import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark import Column
 from snowflake.snowpark._internal.analyzer.unary_expression import Alias
-from snowflake.snowpark.types import DataType
+from snowflake.snowpark.types import DataType, StructType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
+    make_unique_snowpark_name,
 )
 from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.dataframe_container import (
@@ -57,6 +58,47 @@ def map_group_by_aggregate(
         *columns.aggregation_expressions()
     )
 
+    for rel_aggregate_expression, aggregate_original_column in zip(
+        rel.aggregate.aggregate_expressions, columns.aggregation_columns
+    ):
+        aggregate_original_data_type = aggregate_original_column.data_type
+
+        if not (
+            rel_aggregate_expression.HasField("unresolved_function")
+            and rel_aggregate_expression.unresolved_function.function_name == "reduce"
+        ) or not isinstance(aggregate_original_data_type, StructType):
+            continue
+
+        cols = []
+        new_snowpark_column_names = []
+        new_snowpark_column_types = [
+            field.datatype for field in aggregate_original_data_type.fields
+        ]
+
+        if not result.columns or len(result.columns) != 1:
+            raise ValueError(
+                "Expected result DataFrame to have exactly one column for reduce(StructType)"
+            )
+        aggregate_col = snowpark_fn.col(result.columns[0])
+
+        # Extract each field from the StructType result after aggregation to create separate columns
+        for spark_col_name in input_df_container.column_map.get_spark_columns():
+            unique_snowpark_name = make_unique_snowpark_name(spark_col_name)
+            cols.append(
+                snowpark_fn.get(aggregate_col, snowpark_fn.lit(spark_col_name)).alias(
+                    unique_snowpark_name
+                )
+            )
+            new_snowpark_column_names.append(unique_snowpark_name)
+
+        result = result.select(*cols)
+        return DataFrameContainer.create_with_column_mapping(
+            dataframe=result,
+            spark_column_names=input_df_container.column_map.get_spark_columns(),
+            snowpark_column_names=new_snowpark_column_names,
+            snowpark_column_types=new_snowpark_column_types,
+        )
+
     # Store aggregate metadata for ORDER BY resolution
     aggregate_metadata = AggregateMetadata(
         input_column_map=input_df_container.column_map,
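Note on the reduce(StructType) handling added above: the aggregated result arrives as a single struct column, and the new code fans it back out into one Snowpark column per original Spark column via GET(<struct>, '<field>'). A stand-alone sketch of that fan-out step using public Snowpark APIs (the helper name and the assumption that the struct sits in the DataFrame's only column are illustrative, not taken from the package):

    from snowflake.snowpark import DataFrame
    import snowflake.snowpark.functions as F

    def expand_struct_column(df: DataFrame, field_names: list[str]) -> DataFrame:
        # Fan a single struct/OBJECT column out into one column per field,
        # mirroring the snowpark_fn.get(...).alias(...) pattern in the diff.
        struct_col = F.col(df.columns[0])
        return df.select(
            *[F.get(struct_col, F.lit(name)).alias(name) for name in field_names]
        )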
--- a/snowflake/snowpark_connect/relation/read/map_read_csv.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_csv.py
@@ -11,7 +11,18 @@ from pyspark.errors.exceptions.base import AnalysisException
 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
 from snowflake.snowpark.dataframe_reader import DataFrameReader
-from snowflake.snowpark.types import
+from snowflake.snowpark.types import (
+    DataType,
+    DecimalType,
+    DoubleType,
+    IntegerType,
+    LongType,
+    StringType,
+    StructField,
+    StructType,
+    _FractionalType,
+    _IntegralType,
+)
 from snowflake.snowpark_connect.config import global_config, str_to_bool
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -26,6 +37,10 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import (
+    _integral_types_conversion_enabled,
+    emulate_integral_types,
+)
 from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -108,7 +123,7 @@ def map_read_csv(
         df = df.union_all(reader.csv(p))
 
     if schema is None and not str_to_bool(
-        str(raw_options.get("inferSchema", "false"))
+        str(raw_options.get("inferSchema", raw_options.get("inferschema", "false")))
     ):
         df = df.select(
             [snowpark_fn.col(c).cast("STRING").alias(c) for c in df.schema.names]
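The inferSchema change above makes the option lookup tolerate an all-lowercase key as well as the camelCase spelling. The same lookup in isolation (the boolean helper here is a simplified stand-in for the package's str_to_bool):

    def _str_to_bool(value: str) -> bool:
        # Simplified stand-in for snowpark_connect's str_to_bool helper.
        return value.strip().lower() in ("true", "1", "yes")

    def infer_schema_requested(raw_options: dict) -> bool:
        # Prefer the camelCase key, then fall back to the lowercase variant.
        raw = raw_options.get("inferSchema", raw_options.get("inferschema", "false"))
        return _str_to_bool(str(raw))

    assert infer_schema_requested({"inferschema": "TRUE"})
    assert not infer_schema_requested({})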
@@ -123,7 +138,9 @@ def map_read_csv(
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[
+        snowpark_column_types=[
+            _emulate_integral_types_for_csv(f.datatype) for f in df.schema.fields
+        ],
     )
 
 
@@ -320,7 +337,13 @@ def read_data(
     # Create schema with the column names and read CSV
     if len(headers) > 0:
         if (
-            not str_to_bool(
+            not str_to_bool(
+                str(
+                    raw_options.get(
+                        "inferSchema", raw_options.get("inferschema", "false")
+                    )
+                )
+            )
             and schema is None
         ):
             inferred_schema = StructType(
@@ -350,3 +373,49 @@ def read_data(
 
     # Fallback: no headers, shouldn't reach here
     return reader.csv(path)
+
+
+def _emulate_integral_types_for_csv(t: DataType) -> DataType:
+    """
+    CSV requires different type handling to match OSS Spark CSV schema inference.
+
+    After applying emulate_integral_types, converts to Spark CSV types:
+    - IntegerType, ShortType, ByteType -> IntegerType
+    - LongType -> LongType
+    - DecimalType with scale > 0 -> DoubleType
+    - DecimalType with precision > 18 -> DecimalType (too big for long)
+    - DecimalType with precision > 9 -> LongType
+    - DecimalType with precision <= 9 -> IntegerType
+    - FloatType, DoubleType -> DoubleType
+    """
+    if not _integral_types_conversion_enabled:
+        return t
+
+    # First apply standard integral type conversion
+    t = emulate_integral_types(t)
+
+    if isinstance(t, LongType):
+        return LongType()
+
+    elif isinstance(t, _IntegralType):
+        # ByteType, ShortType, IntegerType -> IntegerType
+        return IntegerType()
+
+    elif isinstance(t, DecimalType):
+        # DecimalType with scale > 0 means it has decimal places -> DoubleType
+        if t.scale > 0:
+            return DoubleType()
+        # DecimalType with scale = 0 is integral
+        if t.precision > 18:
+            # Too big for long, keep as DecimalType
+            return DecimalType(t.precision, 0)
+        elif t.precision > 9:
+            return LongType()
+        else:
+            return IntegerType()
+
+    elif isinstance(t, _FractionalType):
+        # FloatType, DoubleType -> DoubleType
+        return DoubleType()
+
+    return t
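Hedged examples of the CSV type mapping documented in _emulate_integral_types_for_csv above, assuming emulate_integral_types (defined in the new type_support.py, which is not shown in this diff) leaves these particular inputs unchanged:

    from snowflake.snowpark.types import (
        DecimalType,
        DoubleType,
        FloatType,
        IntegerType,
        LongType,
    )

    # Illustrative input -> expected output pairs, per the docstring above.
    expected_csv_mappings = [
        (DecimalType(5, 0), IntegerType()),        # scale 0, precision <= 9
        (DecimalType(12, 0), LongType()),          # scale 0, 9 < precision <= 18
        (DecimalType(25, 0), DecimalType(25, 0)),  # scale 0, precision > 18: too big for a long
        (DecimalType(10, 2), DoubleType()),        # scale > 0: has decimal places
        (FloatType(), DoubleType()),               # fractional types widen to double
    ]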
--- a/snowflake/snowpark_connect/relation/read/map_read_jdbc.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_jdbc.py
@@ -16,6 +16,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     Connection,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 
 
@@ -112,7 +113,9 @@ def map_read_jdbc(
             dataframe=renamed_df,
             spark_column_names=true_names,
             snowpark_column_names=snowpark_cols,
-            snowpark_column_types=[
+            snowpark_column_types=[
+                emulate_integral_types(f.datatype) for f in df.schema.fields
+            ],
         )
     except Exception as e:
         exception = Exception(f"Error accessing JDBC datasource for read: {e}")
--- a/snowflake/snowpark_connect/relation/read/map_read_json.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_json.py
@@ -45,6 +45,7 @@ from snowflake.snowpark_connect.type_mapping import (
     map_simple_types,
     merge_different_types,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -148,7 +149,9 @@ def map_read_json(
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
     )
 
 
--- a/snowflake/snowpark_connect/relation/read/map_read_parquet.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_parquet.py
@@ -44,6 +44,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     apply_metadata_exclusion_pattern,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -126,7 +127,9 @@ def map_read_parquet(
         dataframe=renamed_df,
         spark_column_names=[analyzer_utils.unquote_if_quoted(c) for c in df.columns],
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
         can_be_cached=can_be_cached,
     )
 
--- a/snowflake/snowpark_connect/relation/read/map_read_socket.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_socket.py
@@ -11,6 +11,7 @@ from snowflake import snowpark
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -58,6 +59,9 @@ def map_read_socket(
             dataframe=df,
             spark_column_names=[spark_cname],
             snowpark_column_names=[snowpark_cname],
+            snowpark_column_types=[
+                emulate_integral_types(f.datatype) for f in df.schema.fields
+            ],
         )
     except OSError as e:
         exception = Exception(f"Error connecting to {host}:{port} - {e}")
--- a/snowflake/snowpark_connect/relation/read/map_read_table.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_table.py
@@ -24,6 +24,7 @@ from snowflake.snowpark_connect.error.error_utils import attach_custom_error_cod
 from snowflake.snowpark_connect.relation.read.utils import (
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.context import get_processed_views
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
@@ -58,7 +59,9 @@ def post_process_df(
         dataframe=renamed_df,
         spark_column_names=true_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
         column_qualifiers=[{ColumnQualifier(tuple(name_parts))} for _ in true_names]
         if source_table_name
         else None,
--- a/snowflake/snowpark_connect/relation/read/map_read_text.py
+++ b/snowflake/snowpark_connect/relation/read/map_read_text.py
@@ -14,6 +14,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.type_support import emulate_integral_types
 from snowflake.snowpark_connect.utils.io_utils import file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
@@ -117,5 +118,7 @@ def map_read_text(
         dataframe=renamed_df,
         spark_column_names=spark_column_names,
         snowpark_column_names=snowpark_column_names,
-        snowpark_column_types=[
+        snowpark_column_types=[
+            emulate_integral_types(f.datatype) for f in df.schema.fields
+        ],
     )
--- a/snowflake/snowpark_connect/relation/read/reader_config.py
+++ b/snowflake/snowpark_connect/relation/read/reader_config.py
@@ -441,4 +441,10 @@ class ParquetReaderConfig(ReaderWriterConfig):
             "snowpark.connect.parquet.useVectorizedScanner"
         )
 
+        # Set USE_LOGICAL_TYPE from global config to properly handle Parquet logical types like TIMESTAMP.
+        # Without this, Parquet TIMESTAMP (INT64 physical) is incorrectly read as NUMBER(38,0).
+        snowpark_args["USE_LOGICAL_TYPE"] = global_config._get_config_setting(
+            "snowpark.connect.parquet.useLogicalType"
+        )
+
         return snowpark_args
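The new USE_LOGICAL_TYPE copy option is driven by the snowpark.connect.parquet.useLogicalType setting read from the server's global config. Assuming it is surfaced to clients the same way as the neighbouring useVectorizedScanner option (as a Spark conf routed to the server), toggling it from a Spark Connect session might look like the sketch below; the remote URL and stage path are placeholders and the conf routing is an assumption, not documented API:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    # Assumed conf key, mirroring the setting name read in reader_config.py above.
    spark.conf.set("snowpark.connect.parquet.useLogicalType", "true")

    df = spark.read.parquet("@my_stage/events/")  # illustrative stage path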
--- a/snowflake/snowpark_connect/resources_initializer.py
+++ b/snowflake/snowpark_connect/resources_initializer.py
@@ -3,7 +3,12 @@
 #
 import threading
 import time
+from collections.abc import Callable
+from pathlib import Path
 
+from snowflake.snowpark_connect.client.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.config import get_scala_version
+from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 
@@ -11,22 +16,31 @@ SPARK_VERSION = "3.5.6"
 RESOURCE_PATH = "/snowflake/snowpark_connect/resources"
 
 # On demand Scala UDF jar upload state - separate from general resource initialization
-
-
+_scala_2_12_jars_uploaded = threading.Event()
+_scala_2_12_jars_lock = threading.Lock()
+_scala_2_13_jars_uploaded = threading.Event()
+_scala_2_13_jars_lock = threading.Lock()
 
 # Define Scala resource names
-
-
-
-
-
-
-
-
-
+SPARK_SQL_JAR_212 = f"spark-sql_2.12-{SPARK_VERSION}.jar"
+SPARK_CONNECT_CLIENT_JAR_212 = f"spark-connect-client-jvm_2.12-{SPARK_VERSION}.jar"
+SPARK_COMMON_UTILS_JAR_212 = f"spark-common-utils_2.12-{SPARK_VERSION}.jar"
+SAS_SCALA_UDF_JAR_212 = "sas-scala-udf_2.12-0.2.0.jar"
+JSON_4S_JAR_212 = "json4s-ast_2.12-3.7.0-M11.jar"
+SCALA_REFLECT_JAR_212 = "scala-reflect-2.12.18.jar"
+
+# Static dependencies for Scala 2.13
+SPARK_SQL_JAR_213 = f"spark-sql_2.13-{SPARK_VERSION}.jar"
+SPARK_CONNECT_CLIENT_JAR_213 = f"spark-connect-client-jvm_2.13-{SPARK_VERSION}.jar"
+SPARK_COMMON_UTILS_JAR_213 = f"spark-common-utils_2.13-{SPARK_VERSION}.jar"
+SAS_SCALA_UDF_JAR_213 = "sas-scala-udf_2.13-0.2.0.jar"
+JSON_4S_JAR_213 = "json4s-ast_2.13-3.7.0-M11.jar"
+SCALA_REFLECT_JAR_213 = "scala-reflect-2.13.16.jar"
+
+
+def _upload_scala_udf_jars(jar_files: list[str]) -> None:
     """Upload Spark jar files required for creating Scala UDFs.
     This is the internal implementation - use ensure_scala_udf_jars_uploaded() for thread-safe lazy loading."""
-    from pathlib import Path
 
     session = get_or_create_snowpark_session()
     stage = session.get_session_stage()
@@ -34,15 +48,6 @@ def _upload_scala_udf_jars_impl() -> None:
     import snowpark_connect_deps_1
     import snowpark_connect_deps_2
 
-    jar_files = [
-        SPARK_SQL_JAR,
-        SPARK_CONNECT_CLIENT_JAR,
-        SPARK_COMMON_UTILS_JAR,
-        SAS_SCALA_UDF_JAR,
-        JSON_4S_JAR,
-        SCALA_REFLECT_JAR,  # Required for deserializing Scala lambdas
-    ]
-
     # Path to includes/jars directory
     includes_jars_dir = Path(__file__).parent / "includes" / "jars"
 
@@ -78,31 +83,87 @@ def _upload_scala_udf_jars_impl() -> None:
             raise RuntimeError(f"Failed to upload JAR {jar_name}: {e}")
 
 
-def
-
-
+def _upload_scala_2_12_jars() -> None:
+    scala_2_12_jars = [
+        SPARK_SQL_JAR_212,
+        SPARK_CONNECT_CLIENT_JAR_212,
+        SPARK_COMMON_UTILS_JAR_212,
+        SAS_SCALA_UDF_JAR_212,
+        JSON_4S_JAR_212,
+        SCALA_REFLECT_JAR_212,  # Required for deserializing Scala lambdas
+    ]
+    _upload_scala_udf_jars(scala_2_12_jars)
 
+
+def _upload_scala_2_13_jars() -> None:
+    scala_2_13_jars = [
+        SPARK_SQL_JAR_213,
+        SPARK_CONNECT_CLIENT_JAR_213,
+        SPARK_COMMON_UTILS_JAR_213,
+        SAS_SCALA_UDF_JAR_213,
+        JSON_4S_JAR_213,
+        SCALA_REFLECT_JAR_213,
+    ]
+    _upload_scala_udf_jars(scala_2_13_jars)
+
+
+def _ensure_configured_scala_jars_uploaded(
+    jars_uploaded: threading.Event, lock: threading.Lock, upload_fn: Callable[[], None]
+) -> None:
+    """
+    Ensure Scala UDF jars are uploaded to Snowflake, uploading them lazily if not already done.
+    This function is thread-safe and will only upload once even if called from multiple threads.
+
+    Uses the given upload_fn to upload Scala jars if the jars_uploaded event is not set yet.
+    """
     # Fast path: if already uploaded, return immediately without acquiring lock
-    if
+    if jars_uploaded.is_set():
         return
 
     # Slow path: need to upload, acquire lock to ensure only one thread does it
-    with
+    with lock:
         # Double-check pattern: another thread might have uploaded while we waited for the lock
-        if
+        if jars_uploaded.is_set():
             return
 
         try:
             start_time = time.time()
             logger.info("Uploading Scala UDF jars on-demand...")
-
-
+            upload_fn()
+            jars_uploaded.set()
             logger.info(f"Scala UDF jars uploaded in {time.time() - start_time:.2f}s")
         except Exception as e:
             logger.error(f"Failed to upload Scala UDF jars: {e}")
             raise
 
 
+def ensure_scala_udf_jars_uploaded() -> None:
+    """
+    Public function to make sure Scala jars are uploaded and available for imports.
+    """
+    scala_version = get_scala_version()
+
+    match scala_version:
+        case "2.12":
+            _ensure_configured_scala_jars_uploaded(
+                _scala_2_12_jars_uploaded,
+                _scala_2_12_jars_lock,
+                _upload_scala_2_12_jars,
+            )
+        case "2.13":
+            _ensure_configured_scala_jars_uploaded(
+                _scala_2_13_jars_uploaded,
+                _scala_2_13_jars_lock,
+                _upload_scala_2_13_jars,
+            )
+        case _:
+            exception = ValueError(
+                f"Unsupported Scala version: {scala_version}. Snowpark Connect supports Scala 2.12 and 2.13"
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CONFIG_VALUE)
+            raise exception
+
+
 def initialize_resources() -> None:
     """Initialize all expensive resources. We should initialize what we can here, so that actual rpc calls like
     ExecutePlan are as fast as possible."""
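_ensure_configured_scala_jars_uploaded above is a double-checked, run-once guard built from a threading.Event plus a threading.Lock per Scala version. A generic, self-contained sketch of the same pattern (names are illustrative, not the package's):

    import threading
    from collections.abc import Callable

    _done = threading.Event()
    _lock = threading.Lock()

    def run_once(work: Callable[[], None]) -> None:
        # Fast path: skip the lock entirely once the work has completed.
        if _done.is_set():
            return
        with _lock:
            # Double-check: another thread may have finished while we waited.
            if _done.is_set():
                return
            work()        # only marked done if the work did not raise
            _done.set()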
--- a/snowflake/snowpark_connect/server.py
+++ b/snowflake/snowpark_connect/server.py
@@ -45,7 +45,10 @@ import snowflake.snowpark_connect.proto.control_pb2_grpc as control_grpc
 import snowflake.snowpark_connect.tcm as tcm
 from snowflake import snowpark
 from snowflake.snowpark_connect.analyze_plan.map_tree_string import map_tree_string
-from snowflake.snowpark_connect.config import
+from snowflake.snowpark_connect.config import (
+    route_config_proto,
+    set_java_udf_creator_initialized_state,
+)
 from snowflake.snowpark_connect.constants import SERVER_SIDE_SESSION_ID
 from snowflake.snowpark_connect.control_server import ControlServicer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -112,9 +115,6 @@ from snowflake.snowpark_connect.utils.interrupt import (
     interrupt_queries_with_tag,
     interrupt_query,
 )
-from snowflake.snowpark_connect.utils.java_stored_procedure import (
-    set_java_udf_creator_initialized_state,
-)
 from snowflake.snowpark_connect.utils.open_telemetry import (
     is_telemetry_enabled,
     otel_attach_context,
@@ -1126,51 +1126,16 @@ def start_jvm():
         attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
         raise exception
 
-    import pathlib
-    import zipfile
-
-    import snowflake.snowpark_connect
-
     # Import both JAR dependency packages
     import snowpark_connect_deps_1
     import snowpark_connect_deps_2
 
-    #
-    pyspark_jars = (
-        pathlib.Path(snowflake.snowpark_connect.__file__).parent / "includes" / "jars"
-    )
-
-    if "dataframe_processor.zip" in str(pyspark_jars):
-        # importlib.resource doesn't work when local stage package is used in TCM
-        zip_path = pathlib.Path(
-            snowflake.snowpark_connect.__file__
-        ).parent.parent.parent
-        temp_dir = tempfile.gettempdir()
-        extract_folder = "snowflake/snowpark_connect/includes/jars/"  # Folder to extract (must end with '/')
-
-        with zipfile.ZipFile(zip_path, "r") as zip_ref:
-            for member in zip_ref.namelist():
-                if member.startswith(extract_folder):
-                    zip_ref.extract(member, path=temp_dir)
-        pyspark_jars = pathlib.Path(temp_dir) / extract_folder
-
-    included_jar_names = set()
-
-    if pyspark_jars.exists():
-        for jar_path in pyspark_jars.glob(
-            "**/*.jar"
-        ):  # Use **/*.jar to handle nested paths in TCM
-            jpype.addClassPath(str(jar_path))
-            included_jar_names.add(jar_path.name)
-
-    # Load jar files from both packages, skipping those already loaded from includes/jars
+    # Load all the jar files from both packages
     jar_path_list = (
         snowpark_connect_deps_1.list_jars() + snowpark_connect_deps_2.list_jars()
    )
     for jar_path in jar_path_list:
-
-        if jar_path.name not in included_jar_names:
-            jpype.addClassPath(jar_path)
+        jpype.addClassPath(jar_path)
 
     # TODO: Should remove convertStrings, but it breaks the JDBC code.
     jvm_settings: list[str] = list(
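With the includes/jars extraction logic removed, start_jvm now builds the classpath solely from the two dependency wheels. A minimal sketch of that flow, assuming each package exposes list_jars() returning filesystem paths as shown in the diff:

    import jpype

    def add_dependency_jars(*dep_packages) -> None:
        # Add every JAR reported by the dependency packages to the JVM classpath.
        for pkg in dep_packages:
            for jar_path in pkg.list_jars():
                jpype.addClassPath(str(jar_path))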
--- a/snowflake/snowpark_connect/server_common/__init__.py
+++ b/snowflake/snowpark_connect/server_common/__init__.py
@@ -369,7 +369,10 @@ def _setup_spark_environment(setup_java_home: bool = True) -> None:
     lightweight client servers that don't need JVM.
     """
     if setup_java_home:
-        if
+        if (
+            os.environ.get("JAVA_HOME") is None
+            or str(os.environ.get("JAVA_HOME")).strip() == ""
+        ):
             try:
                 # For Notebooks on SPCS
                 from jdk4py import JAVA_HOME