snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +12 -3
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_sql.py +112 -53
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
- snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +95 -14
- snowflake/snowpark_connect/server.py +18 -13
- snowflake/snowpark_connect/utils/context.py +21 -14
- snowflake/snowpark_connect/utils/identifiers.py +8 -2
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +3 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
@@ -43,7 +43,12 @@ def read_text(
 ) -> snowpark.DataFrame:
     # TODO: handle stage name with double quotes
     files_paths = get_file_paths_from_stage(path, session)
-
+    # Remove matching quotes from both ends of the path to get the stage name, if present.
+    if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
+        unquoted_path = path[1:-1]
+    else:
+        unquoted_path = path
+    stage_name = unquoted_path.split("/")[0]
     line_sep = options.get("lineSep") or "\n"
     column_name = (
         schema[0].name if schema is not None and len(schema.fields) > 0 else '"value"'
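Note: the new branch only strips quotes when the same quote character encloses the whole path. A minimal standalone sketch of that check (the helper name _strip_matching_quotes is illustrative, not part of the package):

def _strip_matching_quotes(path: str) -> str:
    # Strip one pair of matching quotes, mirroring the check added in read_text.
    if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
        return path[1:-1]
    return path


assert _strip_matching_quotes('"@my_stage/data.txt"') == "@my_stage/data.txt"
assert _strip_matching_quotes("@my_stage/data.txt") == "@my_stage/data.txt"
# The stage name is then the first path segment:
assert _strip_matching_quotes('"@my_stage/data.txt"').split("/")[0] == "@my_stage"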
@@ -5,6 +5,7 @@
 import os
 
 from fsspec.core import url_to_fs
+from pyspark.errors.exceptions.base import AnalysisException
 from s3fs.core import S3FileSystem
 
 from snowflake import snowpark
@@ -33,37 +34,42 @@ def get_paths_from_stage(
 
     # TODO : What if GCP?
     # TODO: What if already stage path?
-
-
-
-            _, bucket_name, path = parse_azure_url(p)
-            rewrite_paths.append(f"{stage_name}/{path}")
-        paths = rewrite_paths
-    else:
-        filesystem, parsed_path = url_to_fs(paths[0])
-        if isinstance(filesystem, S3FileSystem):  # aws
-            # Remove bucket name from the path since the stage name will replace
-            # the bucket name in the path.
-            paths = [
-                f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
-                for p in paths
-            ]
-        else:  # local
-            # For local files, we need to preserve directory structure for partitioned data
-            # Instead of just using basename, we'll use the last few path components
-            new_paths = []
+    match get_cloud_from_url(paths[0]):
+        case "azure":
+            rewrite_paths = []
             for p in paths:
-
-
-
-
-
-
-
-
-
-
-
+                _, bucket_name, path = parse_azure_url(p)
+                rewrite_paths.append(f"{stage_name}/{path}")
+            paths = rewrite_paths
+        case "gcp":
+            raise AnalysisException(
+                "You must configure an integration for Google Cloud Storage to perform I/O operations rather than accessing the URL directly. Reference: https://docs.snowflake.com/en/user-guide/data-load-gcs-config"
+            )
+        case _:
+            filesystem, parsed_path = url_to_fs(paths[0])
+            if isinstance(filesystem, S3FileSystem):  # aws
+                # Remove bucket name from the path since the stage name will replace
+                # the bucket name in the path.
+                paths = [
+                    f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
+                    for p in paths
+                ]
+            else:  # local
+                # For local files, we need to preserve directory structure for partitioned data
+                # Instead of just using basename, we'll use the last few path components
+                new_paths = []
+                for p in paths:
+                    # Split the path and take the last 2-3 components to preserve structure
+                    # but avoid very long paths
+                    path_parts = p.split(os.sep)
+                    if len(path_parts) >= 2:
+                        # Take last 2 components (e.g., "base_case/x=abc")
+                        relative_path = "/".join(path_parts[-2:])
+                    else:
+                        # Single component, use basename
+                        relative_path = os.path.basename(p)
+                    new_paths.append(f"{stage_name}/{relative_path}")
+                paths = new_paths
 
     return paths
 
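Note: a rough, self-contained sketch of the path rewriting the non-Azure branches perform, assuming stage_name is already resolved and skipping the fsspec filesystem detection the real code uses:

import os


def rewrite_for_stage(paths: list[str], stage_name: str, is_s3: bool) -> list[str]:
    # Rough illustration only; the real code resolves the filesystem via fsspec's url_to_fs.
    if is_s3:
        # Drop the bucket name; the stage already points at the bucket.
        return [
            f"{stage_name}/{'/'.join(p.split('://', 1)[-1].split('/')[1:])}" for p in paths
        ]
    # Local files: keep the last two path components so partition dirs survive.
    out = []
    for p in paths:
        parts = p.split(os.sep)
        rel = "/".join(parts[-2:]) if len(parts) >= 2 else os.path.basename(p)
        out.append(f"{stage_name}/{rel}")
    return out


print(rewrite_for_stage(["s3://bucket/base_case/x=abc/part-0.parquet"], "@tmp_stage", True))
# ['@tmp_stage/base_case/x=abc/part-0.parquet']
print(rewrite_for_stage(["/tmp/base_case/x=abc"], "@tmp_stage", False))
# ['@tmp_stage/base_case/x=abc']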
@@ -102,15 +108,21 @@ class StageLocator:
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='azure://{account}.blob.core.windows.net/{bucket_name}'"
 
             credential_session_key = (
-                f"fs.azure.sas.
+                f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
+                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
             )
             credential = sessions_config.get(spark_session_id, None)
-
-
-
-
-
-
+            sas_token = None
+            for session_key in credential_session_key:
+                if (
+                    credential is not None
+                    and credential.get(session_key) is not None
+                    and credential.get(session_key).strip() != ""
+                ):
+                    sas_token = credential.get(session_key)
+                    break
+            if sas_token is not None:
+                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{sas_token}')"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_azure[bucket_name] = stage_name
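Note: a condensed sketch of the SAS token selection added above; the two candidate Spark config keys come from the diff, while the helper itself is illustrative:

def pick_sas_token(credential: dict | None, account: str, bucket_name: str) -> str | None:
    # Mirrors the lookup order in the diff: the fixed dfs token key first,
    # then the per-container blob key; the first non-empty value wins.
    candidate_keys = (
        f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
        f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
    )
    if credential is None:
        return None
    for key in candidate_keys:
        value = credential.get(key)
        if value is not None and value.strip() != "":
            return value
    return None


conf = {"fs.azure.sas.mycontainer.myacct.blob.core.windows.net": "sv=...&sig=..."}
print(pick_sas_token(conf, "myacct", "mycontainer"))  # -> "sv=...&sig=..."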
@@ -128,24 +140,44 @@ class StageLocator:
             # but the rest of the time it's used, it does. We just drop it here.
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='s3://{parsed_path.split('/')[0]}'"
             credential = sessions_config.get(spark_session_id, None)
-            if
-
-
-
-
-
-
-
-
-
-
-
-
+            if credential is not None:
+                if (  # USE AWS KEYS to connect
+                    credential.get("spark.hadoop.fs.s3a.access.key") is not None
+                    and credential.get("spark.hadoop.fs.s3a.secret.key")
+                    is not None
+                    and credential.get("spark.hadoop.fs.s3a.access.key").strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
+                    != ""
+                ):
+                    aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
+                    aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
+                    if (
+                        credential.get("spark.hadoop.fs.s3a.session.token")
+                        is not None
+                    ):
+                        aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
+                    sql_query += f" CREDENTIALS = ({aws_keys})"
+                    sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                elif (  # USE AWS ROLE and KMS KEY to connect
+                    credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    )
+                    is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    ).strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.assumed.role.arn")
                     is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.assumed.role.arn"
+                    ).strip()
+                    != ""
                 ):
-
-
-
+                    aws_role = f" AWS_ROLE = '{credential.get('spark.hadoop.fs.s3a.assumed.role.arn')}'"
+                    sql_query += f" CREDENTIALS = ({aws_role})"
+                    sql_query += f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.server-side-encryption.key')}' )"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_aws[bucket_name] = stage_name
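Note: a simplified sketch of how the CREDENTIALS/ENCRYPTION clauses are assembled for the AWS stage, collapsing the two branches above into one helper; spacing and default handling differ slightly from the real code:

def build_stage_credentials_clause(credential: dict) -> str:
    # Condensed version of the two branches in the diff (static keys vs. role + KMS key).
    access = credential.get("spark.hadoop.fs.s3a.access.key", "").strip()
    secret = credential.get("spark.hadoop.fs.s3a.secret.key", "").strip()
    token = credential.get("spark.hadoop.fs.s3a.session.token")
    role = credential.get("spark.hadoop.fs.s3a.assumed.role.arn", "").strip()
    kms_key = credential.get("spark.hadoop.fs.s3a.server-side-encryption.key", "").strip()

    if access and secret:
        clause = f" CREDENTIALS = ( AWS_KEY_ID = '{access}' AWS_SECRET_KEY = '{secret}'"
        if token is not None:
            clause += f" AWS_TOKEN = '{token}'"
        return clause + " ) ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
    if role and kms_key:
        return (
            f" CREDENTIALS = ( AWS_ROLE = '{role}' )"
            f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{kms_key}' )"
        )
    return ""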
@@ -36,6 +36,8 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.io_utils import (
     convert_file_prefix_path,
     is_cloud_path,
+    is_supported_compression,
+    supported_compressions_for_format,
 )
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
@@ -179,7 +181,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 f"Skipping REMOVE for root path {write_path} - too broad scope"
             )
         else:
-            remove_command = f"REMOVE {write_path}/"
+            remove_command = f"REMOVE '{write_path}/'"
         session.sql(remove_command).collect()
         logger.info(f"Successfully cleared directory: {write_path}")
     except Exception as e:
@@ -208,6 +210,20 @@ def map_write(request: proto_base.ExecutePlanRequest):
     compression = write_op.options.get(
         "compression", default_compression
     ).upper()
+
+    if not is_supported_compression(write_op.source, compression):
+        supported_compressions = supported_compressions_for_format(
+            write_op.source
+        )
+        raise AnalysisException(
+            f"Compression {compression} is not supported for {write_op.source} format. "
+            + (
+                f"Supported compressions: {sorted(supported_compressions)}"
+                if supported_compressions
+                else "No compression supported for this format."
+            )
+        )
+
     parameters = {
         "location": temp_file_prefix_on_stage,
         "file_format_type": write_op.source
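Note: is_supported_compression and supported_compressions_for_format are imported from relation/io_utils.py, whose body is not shown in this diff. A hypothetical sketch of their shape, with an illustrative (not authoritative) compression table:

# Hypothetical helper shapes; the real table of supported compressions may differ.
_SUPPORTED_COMPRESSIONS = {
    "csv": {"NONE", "GZIP"},
    "json": {"NONE", "GZIP"},
    "parquet": {"NONE", "SNAPPY", "GZIP"},
}


def supported_compressions_for_format(source: str) -> set[str]:
    return _SUPPORTED_COMPRESSIONS.get(source.lower(), set())


def is_supported_compression(source: str, compression: str) -> bool:
    return compression.upper() in supported_compressions_for_format(source)


print(is_supported_compression("parquet", "snappy"))  # True
print(is_supported_compression("csv", "ZSTD"))        # False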
@@ -218,8 +234,9 @@ def map_write(request: proto_base.ExecutePlanRequest):
         },
         "overwrite": overwrite,
     }
-    #
-
+    # Download from the base write path to ensure we fetch whatever Snowflake produced.
+    # Using the base avoids coupling to exact filenames/prefixes.
+    download_stage_path = write_path
 
     # Check for partition hint early to determine precedence over single option
     partition_hint = result.partition_hint
@@ -238,13 +255,19 @@ def map_write(request: proto_base.ExecutePlanRequest):
             raise SnowparkConnectNotImplementedError(
                 "Partitioning is only supported for parquet format"
             )
-
-
-
-
-
-
-
+        # Build Spark-style directory structure: col1=value1/col2=value2/...
+        # Example produced expression (Snowflake SQL):
+        # 'department=' || TO_VARCHAR("department") || '/' || 'region=' || TO_VARCHAR("region")
+        partitioning_column_names = list(write_op.partitioning_columns)
+        partition_expr_parts: list[str] = []
+        for col_name in partitioning_column_names:
+            quoted = f'"{col_name}"'
+            segment = f"'{col_name}=' || COALESCE(TO_VARCHAR({quoted}), '__HIVE_DEFAULT_PARTITION__')"
+            partition_expr_parts.append(segment)
+        parameters["partition_by"] = " || '/' || ".join(partition_expr_parts)
+        # When using PARTITION BY, Snowflake writes into subdirectories under the base path.
+        # Download from the base write path to preserve partition directories locally.
+        download_stage_path = write_path
 
     # If a partition hint is present (from DataFrame.repartition(n)), optionally split the
     # write into n COPY INTO calls by assigning a synthetic partition id. Controlled by config.
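Note: running the expression-building loop above for two made-up partition columns shows the PARTITION BY expression handed to Snowflake:

partitioning_column_names = ["department", "region"]
partition_expr_parts = []
for col_name in partitioning_column_names:
    quoted = f'"{col_name}"'
    partition_expr_parts.append(
        f"'{col_name}=' || COALESCE(TO_VARCHAR({quoted}), '__HIVE_DEFAULT_PARTITION__')"
    )
print(" || '/' || ".join(partition_expr_parts))
# (output wrapped) 'department=' || COALESCE(TO_VARCHAR("department"), '__HIVE_DEFAULT_PARTITION__')
#   || '/' || 'region=' || COALESCE(TO_VARCHAR("region"), '__HIVE_DEFAULT_PARTITION__')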
@@ -410,9 +433,27 @@ def map_write(request: proto_base.ExecutePlanRequest):
             )
         case _:
             snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
+            save_method = write_op.table.save_method
+
+            if (
+                write_op.source == "snowflake"
+                and write_op.table.save_method
+                == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_UNSPECIFIED
+            ):
+                save_method = (
+                    commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
+                )
+                if len(write_op.table.table_name) == 0:
+                    dbtable_name = write_op.options.get("dbtable", "")
+                    if len(dbtable_name) == 0:
+                        raise SnowparkConnectNotImplementedError(
+                            "Save command is not supported without a table name"
+                        )
+                    else:
+                        snowpark_table_name = _spark_to_snowflake(dbtable_name)
 
             if (
-
+                save_method
                 == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
             ):
                 match write_mode:
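Note: an illustrative Spark Connect client call that would exercise this path; the "dbtable" option name comes from the diff, while the DataFrame and table name are made up:

# df is an existing DataFrame (placeholder). Calling save() without saveAsTable()
# leaves the save method UNSPECIFIED, which the new code treats as saveAsTable for
# the "snowflake" source and resolves the table name from the dbtable option.
(
    df.write.format("snowflake")
    .option("dbtable", "MY_DB.PUBLIC.TARGET_TABLE")
    .mode("overwrite")
    .save()
)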
@@ -474,7 +515,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                     column_order=_column_order_for_write,
                 )
             elif (
-
+                save_method
                 == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
             ):
                 _validate_schema_and_get_writer(
@@ -486,7 +527,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 )
             else:
                 raise SnowparkConnectNotImplementedError(
-                    f"Save command not supported: {
+                    f"Save command not supported: {save_method}"
                 )
 
 
@@ -978,7 +1019,47 @@ def store_files_locally(
     )
     if overwrite and os.path.isdir(target_path):
         _truncate_directory(real_path)
-
+    # Per Snowflake docs: "The command does not preserve stage directory structure when transferring files to your client machine"
+    # https://docs.snowflake.com/en/sql-reference/sql/get
+    # Preserve directory structure under stage_path by listing files and
+    # downloading each into its corresponding local subdirectory when partition subdirs exist.
+    # Otherwise, fall back to a direct GET which flattens.
+
+    # TODO(SNOW-2326973): This can be parallelized further. Its not done here because it only affects
+    # write to local storage.
+
+    ls_dataframe = session.sql(f"LS {stage_path}")
+    ls_iterator = ls_dataframe.toLocalIterator()
+
+    # Build a normalized base prefix from stage_path to compute relatives
+    # Example: stage_path='@MY_STAGE/prefix' -> base_prefix='my_stage/prefix/'
+    base_prefix = stage_path.lstrip("@").rstrip("/") + "/"
+    base_prefix_lower = base_prefix.lower()
+
+    # Group by parent directory under the base prefix, then issue a GET per directory.
+    # This gives a small parallelism advantage if we have many files per partition directory.
+    parent_dirs: set[str] = set()
+    for row in ls_iterator:
+        name: str = row[0]
+        name_lower = name.lower()
+        rel_start = name_lower.find(base_prefix_lower)
+        relative = name[rel_start + len(base_prefix) :] if rel_start != -1 else name
+        parent_dir = os.path.dirname(relative)
+        if parent_dir and parent_dir != ".":
+            parent_dirs.add(parent_dir)
+
+    # If no parent directories were discovered (non-partitioned unload prefix), use direct GET.
+    if not parent_dirs:
+        snowpark.file_operation.FileOperation(session).get(stage_path, str(real_path))
+        return
+
+    file_op = snowpark.file_operation.FileOperation(session)
+    for parent_dir in sorted(parent_dirs):
+        local_dir = real_path / parent_dir
+        os.makedirs(local_dir, exist_ok=True)
+
+        src_dir = f"@{base_prefix}{parent_dir}"
+        file_op.get(src_dir, str(local_dir))
 
 
 def _truncate_directory(directory_path: Path) -> None:
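Note: a small worked example (with made-up stage listing output) of how the new code groups LS results into parent directories before issuing one GET per directory:

import os

stage_path = "@MY_STAGE/output"
base_prefix = stage_path.lstrip("@").rstrip("/") + "/"  # "MY_STAGE/output/"

# Names as LS might return them (illustrative values).
listed = [
    "my_stage/output/x=1/part-0.parquet",
    "my_stage/output/x=2/part-0.parquet",
]

parent_dirs = set()
for name in listed:
    rel_start = name.lower().find(base_prefix.lower())
    relative = name[rel_start + len(base_prefix):] if rel_start != -1 else name
    parent_dir = os.path.dirname(relative)
    if parent_dir and parent_dir != ".":
        parent_dirs.add(parent_dir)

print(sorted(parent_dirs))  # ['x=1', 'x=2'] -> one GET per partition directory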
@@ -1161,23 +1161,28 @@ def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession:
 
 
 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-
-
-
-
-
-
-
-
-
-
-
-
+    if os.environ.get("JAVA_HOME") is None:
+        try:
+            # For Notebooks on SPCS
+            from jdk4py import JAVA_HOME
+
+            os.environ["JAVA_HOME"] = str(JAVA_HOME)
+        except ModuleNotFoundError:
+            # For notebooks on Warehouse
+            conda_prefix = os.environ.get("CONDA_PREFIX")
+            if conda_prefix is not None:
+                os.environ["JAVA_HOME"] = conda_prefix
+                os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
+                    conda_prefix, "lib", "server"
+                )
+    logger.info("JAVA_HOME=%s", os.environ.get("JAVA_HOME", "Not defined"))
 
     os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
     os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
 
-
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
+
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)
 
@@ -30,9 +30,6 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )
 
-# Context for parsing map_partitions
-_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
-
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -70,6 +67,26 @@ _lca_alias_map: ContextVar[dict[str, TypedColumn]] = ContextVar(
     default={},
 )
 
+_view_process_context = ContextVar("_view_process_context", default=[])
+
+
+@contextmanager
+def push_processed_view(name: str):
+    _view_process_context.set(_view_process_context.get() + [name])
+    yield
+    _view_process_context.set(_view_process_context.get()[:-1])
+
+
+def get_processed_views() -> list[str]:
+    return _view_process_context.get()
+
+
+def register_processed_view(name: str) -> None:
+    context = _view_process_context.get()
+    context.append(name)
+    _view_process_context.set(context)
+
+
 # Context variable to track current grouping columns for grouping_id() function
 _current_grouping_columns: ContextVar[list[str]] = ContextVar(
     "_current_grouping_columns",
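Note: a self-contained usage sketch of the new view-processing stack; the view names are illustrative:

from contextlib import contextmanager
from contextvars import ContextVar

_view_process_context = ContextVar("_view_process_context", default=[])


@contextmanager
def push_processed_view(name: str):
    # Push the view name for the duration of the block, then pop it again.
    _view_process_context.set(_view_process_context.get() + [name])
    yield
    _view_process_context.set(_view_process_context.get()[:-1])


with push_processed_view("view_a"):
    with push_processed_view("view_b"):
        # While resolving nested views, the stack shows the active chain.
        print(_view_process_context.get())  # ['view_a', 'view_b']
print(_view_process_context.get())  # []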
@@ -210,16 +227,6 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
     _is_evaluating_join_condition.set(prev)
 
 
-@contextmanager
-def push_map_partitions():
-    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
-    yield
-
-
-def map_partitions_depth() -> int:
-    return _map_partitions_stack.get()
-
-
 @contextmanager
 def push_sql_scope():
     """
@@ -387,9 +394,9 @@ def clear_context_data() -> None:
     _plan_id_map.set({})
     _alias_map.set({})
 
+    _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
-    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})
@@ -28,12 +28,18 @@ def unquote_spark_identifier_if_quoted(spark_name: str) -> str:
     raise AnalysisException(f"Invalid name: {spark_name}")
 
 
-def spark_to_sf_single_id_with_unquoting(
+def spark_to_sf_single_id_with_unquoting(
+    name: str, use_auto_upper_case: bool = False
+) -> str:
     """
     Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
     Unquotes the spark name if necessary. Will raise an AnalysisException if given name is not valid.
     """
-    return
+    return (
+        spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+        if use_auto_upper_case
+        else quote_name_without_upper_casing(unquote_spark_identifier_if_quoted(name))
+    )
 
 
 def spark_to_sf_single_id(name: str, is_column: bool = False) -> str:
@@ -3,10 +3,46 @@
 #
 import contextlib
 import functools
+import re
 
 from snowflake.snowpark import Session
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    create_file_format_statement,
+)
 from snowflake.snowpark_connect.utils.identifiers import FQN
 
+_MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")
+
+
+def cached_file_format(
+    session: Session, file_format: str, format_type_options: dict[str, str]
+) -> str:
+    """
+    Cache and return a file format name based on the given options.
+    """
+
+    function_name = _MINUS_AT_THE_BEGINNING_REGEX.sub(
+        "1", str(hash(frozenset(format_type_options.items())))
+    )
+    file_format_name = f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{function_name}"
+    if file_format_name in session._file_formats:
+        return file_format_name
+
+    session.sql(
+        create_file_format_statement(
+            file_format_name,
+            file_format,
+            format_type_options,
+            temp=True,
+            if_not_exist=True,
+            use_scoped_temp_objects=False,
+            is_generated=True,
+        )
+    ).collect()
+
+    session._file_formats.add(file_format_name)
+    return file_format_name
+
 
 @functools.cache
 def file_format(
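Note: a sketch of how cached_file_format derives the cached name: the format options are hashed as a frozenset and a leading minus sign is replaced so the suffix stays identifier-friendly. The option keys below are illustrative:

import re

_MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")

options = {"FIELD_DELIMITER": ",", "SKIP_HEADER": "1"}
suffix = _MINUS_AT_THE_BEGINNING_REGEX.sub("1", str(hash(frozenset(options.items()))))
name = f"__SNOWPARK_CONNECT_FILE_FORMAT__CSV_{suffix}"
print(name)
# e.g. __SNOWPARK_CONNECT_FILE_FORMAT__CSV_1234567890123456789 (hash value varies per process)

Because the resulting name is recorded in the per-session session._file_formats set (initialized in the utils/session.py hunk below), a second call with the same options skips the CREATE FILE FORMAT statement.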
@@ -71,6 +71,9 @@ def configure_snowpark_session(session: snowpark.Session):
     init_builtin_udf_cache(session)
     init_external_udxf_cache(session)
 
+    # file format cache
+    session._file_formats = set()
+
     # Set experimental parameters (warnings globally suppressed)
     session.ast_enabled = False
     session.eliminate_numeric_sql_value_cast_enabled = False
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Optional, Tuple
+
+from pyspark.errors import AnalysisException
+
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+from snowflake.snowpark_connect.utils.context import get_session_id
+
+_temp_views = SynchronizedDict[Tuple[str, str], DataFrameContainer]()
+
+
+def register_temp_view(name: str, df: DataFrameContainer, replace: bool) -> None:
+    normalized_name = _normalize(name)
+    current_session_id = get_session_id()
+    for key in list(_temp_views.keys()):
+        if _normalize(key[0]) == normalized_name and key[1] == current_session_id:
+            if replace:
+                _temp_views.remove(key)
+                break
+            else:
+                raise AnalysisException(
+                    f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{name}` because it already exists."
+                )
+
+    _temp_views[(name, current_session_id)] = df
+
+
+def unregister_temp_view(name: str) -> bool:
+    normalized_name = _normalize(name)
+
+    for key in _temp_views.keys():
+        normalized_key = _normalize(key[0])
+        if normalized_name == normalized_key and key[1] == get_session_id():
+            pop_result = _temp_views.remove(key)
+            return pop_result is not None
+    return False
+
+
+def get_temp_view(name: str) -> Optional[DataFrameContainer]:
+    normalized_name = _normalize(name)
+    for key in _temp_views.keys():
+        normalized_key = _normalize(key[0])
+        if normalized_name == normalized_key and key[1] == get_session_id():
+            return _temp_views.get(key)
+    return None
+
+
+def get_temp_view_normalized_names() -> list[str]:
+    return [
+        _normalize(key[0]) for key in _temp_views.keys() if key[1] == get_session_id()
+    ]
+
+
+def _normalize(name: str) -> str:
+    from snowflake.snowpark_connect.config import global_config
+
+    return name if global_config.spark_sql_caseSensitive else name.lower()
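Note: a simplified, self-contained model of the new temporary view cache's keying rules, assuming the Spark default of case-insensitive identifiers; the real module keys entries by (name, Spark session id) in a SynchronizedDict and stores DataFrameContainer values:

# Simplified model: entries are keyed by (name, session_id) and name matching is
# case-insensitive unless spark.sql.caseSensitive is on.
_temp_views: dict[tuple[str, str], object] = {}


def _normalize(name: str, case_sensitive: bool = False) -> str:
    return name if case_sensitive else name.lower()


def register(name: str, session_id: str, df: object, replace: bool) -> None:
    for key in list(_temp_views):
        if _normalize(key[0]) == _normalize(name) and key[1] == session_id:
            if replace:
                del _temp_views[key]
                break
            raise ValueError(f"Temporary view `{name}` already exists.")
    _temp_views[(name, session_id)] = df


register("people", "session-1", "df_v1", replace=False)
register("PEOPLE", "session-1", "df_v2", replace=True)   # same view, replaced
register("people", "session-2", "df_v3", replace=False)  # different session, no clash
print(len(_temp_views))  # 2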