snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client.py +65 -0
- snowflake/snowpark_connect/column_name_handler.py +6 -0
- snowflake/snowpark_connect/config.py +33 -5
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
- snowflake/snowpark_connect/expression/map_extension.py +277 -1
- snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
- snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_join.py +8 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
- snowflake/snowpark_connect/relation/map_show_string.py +14 -6
- snowflake/snowpark_connect/relation/map_sql.py +39 -5
- snowflake/snowpark_connect/relation/map_stats.py +1 -1
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
- snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +67 -4
- snowflake/snowpark_connect/server.py +29 -16
- snowflake/snowpark_connect/type_mapping.py +75 -3
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/telemetry.py +30 -5
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/relation/stage_locator.py
@@ -5,6 +5,7 @@
 import os
 
 from fsspec.core import url_to_fs
+from pyspark.errors.exceptions.base import AnalysisException
 from s3fs.core import S3FileSystem
 
 from snowflake import snowpark
@@ -33,37 +34,42 @@ def get_paths_from_stage(
 
     # TODO : What if GCP?
    # TODO: What if already stage path?
-
-
-
-                _, bucket_name, path = parse_azure_url(p)
-                rewrite_paths.append(f"{stage_name}/{path}")
-            paths = rewrite_paths
-        else:
-            filesystem, parsed_path = url_to_fs(paths[0])
-            if isinstance(filesystem, S3FileSystem):  # aws
-                # Remove bucket name from the path since the stage name will replace
-                # the bucket name in the path.
-                paths = [
-                    f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
-                    for p in paths
-                ]
-            else:  # local
-                # For local files, we need to preserve directory structure for partitioned data
-                # Instead of just using basename, we'll use the last few path components
-                new_paths = []
+    match get_cloud_from_url(paths[0]):
+        case "azure":
+            rewrite_paths = []
             for p in paths:
-
-
-
-
-
-
-
-
-
-
-
+                _, bucket_name, path = parse_azure_url(p)
+                rewrite_paths.append(f"{stage_name}/{path}")
+            paths = rewrite_paths
+        case "gcp":
+            raise AnalysisException(
+                "You must configure an integration for Google Cloud Storage to perform I/O operations rather than accessing the URL directly. Reference: https://docs.snowflake.com/en/user-guide/data-load-gcs-config"
+            )
+        case _:
+            filesystem, parsed_path = url_to_fs(paths[0])
+            if isinstance(filesystem, S3FileSystem):  # aws
+                # Remove bucket name from the path since the stage name will replace
+                # the bucket name in the path.
+                paths = [
+                    f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
+                    for p in paths
+                ]
+            else:  # local
+                # For local files, we need to preserve directory structure for partitioned data
+                # Instead of just using basename, we'll use the last few path components
+                new_paths = []
+                for p in paths:
+                    # Split the path and take the last 2-3 components to preserve structure
+                    # but avoid very long paths
+                    path_parts = p.split(os.sep)
+                    if len(path_parts) >= 2:
+                        # Take last 2 components (e.g., "base_case/x=abc")
+                        relative_path = "/".join(path_parts[-2:])
+                    else:
+                        # Single component, use basename
+                        relative_path = os.path.basename(p)
+                    new_paths.append(f"{stage_name}/{relative_path}")
+                paths = new_paths
 
     return paths
 
@@ -102,15 +108,21 @@ class StageLocator:
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='azure://{account}.blob.core.windows.net/{bucket_name}'"
 
             credential_session_key = (
-                f"fs.azure.sas.
+                f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
+                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
             )
             credential = sessions_config.get(spark_session_id, None)
-
-
-
-
-
-
+            sas_token = None
+            for session_key in credential_session_key:
+                if (
+                    credential is not None
+                    and credential.get(session_key) is not None
+                    and credential.get(session_key).strip() != ""
+                ):
+                    sas_token = credential.get(session_key)
+                    break
+            if sas_token is not None:
+                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{sas_token}')"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_azure[bucket_name] = stage_name
@@ -128,24 +140,44 @@ class StageLocator:
             # but the rest of the time it's used, it does. We just drop it here.
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='s3://{parsed_path.split('/')[0]}'"
             credential = sessions_config.get(spark_session_id, None)
-            if
-
-
-
-
-
-
-
-
-
-
-
-
+            if credential is not None:
+                if (  # USE AWS KEYS to connect
+                    credential.get("spark.hadoop.fs.s3a.access.key") is not None
+                    and credential.get("spark.hadoop.fs.s3a.secret.key")
+                    is not None
+                    and credential.get("spark.hadoop.fs.s3a.access.key").strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
+                    != ""
+                ):
+                    aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
+                    aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
+                    if (
+                        credential.get("spark.hadoop.fs.s3a.session.token")
+                        is not None
+                    ):
+                        aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
+                    sql_query += f" CREDENTIALS = ({aws_keys})"
+                    sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                elif (  # USE AWS ROLE and KMS KEY to connect
+                    credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    )
+                    is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    ).strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.assumed.role.arn")
                     is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.assumed.role.arn"
+                    ).strip()
+                    != ""
                 ):
-
-
-
+                    aws_role = f" AWS_ROLE = '{credential.get('spark.hadoop.fs.s3a.assumed.role.arn')}'"
+                    sql_query += f" CREDENTIALS = ({aws_role})"
+                    sql_query += f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.server-side-encryption.key')}' )"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_aws[bucket_name] = stage_name
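The rewritten get_paths_from_stage dispatches on the cloud provider of the first path instead of special-casing Azure only. Below is a minimal standalone sketch of that dispatch pattern; get_cloud_from_url, the URL parsing, and the stage name handling are simplified stand-ins for the package's own helpers, not the actual implementation.

```python
from urllib.parse import urlparse


def get_cloud_from_url(url: str) -> str:
    # Simplified stand-in: infer the provider from the URL scheme.
    scheme = urlparse(url).scheme
    if scheme in ("wasb", "wasbs", "abfs", "abfss", "azure"):
        return "azure"
    if scheme in ("gs", "gcs"):
        return "gcp"
    return "other"  # s3, local paths, etc.


def rewrite_paths_for_stage(paths: list[str], stage_name: str) -> list[str]:
    match get_cloud_from_url(paths[0]):
        case "azure":
            # Replace the container/account part of each URL with the stage name.
            return [f"{stage_name}/{urlparse(p).path.lstrip('/')}" for p in paths]
        case "gcp":
            raise ValueError(
                "GCS requires a storage integration; direct URLs are not supported."
            )
        case _:
            # Keep only the trailing components so partition directories survive.
            return [f"{stage_name}/{'/'.join(p.split('/')[-2:])}" for p in paths]


print(
    rewrite_paths_for_stage(
        ["wasbs://container@acct.blob.core.windows.net/data/x=1"], "@my_stage"
    )
)  # ['@my_stage/data/x=1']
```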

snowflake/snowpark_connect/relation/write/map_write.py
@@ -36,8 +36,13 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.io_utils import (
     convert_file_prefix_path,
     is_cloud_path,
+    is_supported_compression,
+    supported_compressions_for_format,
 )
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
 from snowflake.snowpark_connect.relation.stage_locator import get_paths_from_stage
 from snowflake.snowpark_connect.relation.utils import (
@@ -127,6 +132,19 @@ def map_write(request: proto_base.ExecutePlanRequest):
 
     result = map_relation(write_op.input)
     input_df: snowpark.DataFrame = handle_column_names(result, write_op.source)
+
+    # Create updated container with transformed dataframe, then filter METADATA$FILENAME columns
+    # Update the container to use the transformed dataframe from handle_column_names
+    updated_result = DataFrameContainer(
+        dataframe=input_df,
+        column_map=result.column_map,
+        table_name=result.table_name,
+        alias=result.alias,
+        partition_hint=result.partition_hint,
+    )
+    updated_result = filter_metadata_columns(updated_result)
+    input_df = updated_result.dataframe
+
     session: snowpark.Session = get_or_create_snowpark_session()
 
     # Snowflake saveAsTable doesn't support format
@@ -179,7 +197,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                     f"Skipping REMOVE for root path {write_path} - too broad scope"
                 )
             else:
-                remove_command = f"REMOVE {write_path}/"
+                remove_command = f"REMOVE '{write_path}/'"
                 session.sql(remove_command).collect()
                 logger.info(f"Successfully cleared directory: {write_path}")
         except Exception as e:
@@ -208,6 +226,20 @@ def map_write(request: proto_base.ExecutePlanRequest):
            compression = write_op.options.get(
                "compression", default_compression
            ).upper()
+
+            if not is_supported_compression(write_op.source, compression):
+                supported_compressions = supported_compressions_for_format(
+                    write_op.source
+                )
+                raise AnalysisException(
+                    f"Compression {compression} is not supported for {write_op.source} format. "
+                    + (
+                        f"Supported compressions: {sorted(supported_compressions)}"
+                        if supported_compressions
+                        else "No compression supported for this format."
+                    )
+                )
+
            parameters = {
                "location": temp_file_prefix_on_stage,
                "file_format_type": write_op.source
@@ -417,9 +449,27 @@ def map_write(request: proto_base.ExecutePlanRequest):
                )
            case _:
                snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
+                save_method = write_op.table.save_method
 
                if (
-                    write_op.
+                    write_op.source == "snowflake"
+                    and write_op.table.save_method
+                    == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_UNSPECIFIED
+                ):
+                    save_method = (
+                        commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
+                    )
+                    if len(write_op.table.table_name) == 0:
+                        dbtable_name = write_op.options.get("dbtable", "")
+                        if len(dbtable_name) == 0:
+                            raise SnowparkConnectNotImplementedError(
+                                "Save command is not supported without a table name"
+                            )
+                        else:
+                            snowpark_table_name = _spark_to_snowflake(dbtable_name)
+
+                if (
+                    save_method
                    == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
                ):
                    match write_mode:
@@ -481,7 +531,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                            column_order=_column_order_for_write,
                        )
                elif (
-
+                    save_method
                    == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
                ):
                    _validate_schema_and_get_writer(
@@ -493,7 +543,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                    )
                else:
                    raise SnowparkConnectNotImplementedError(
-                        f"Save command not supported: {
+                        f"Save command not supported: {save_method}"
                    )
 
 
@@ -503,6 +553,19 @@ def map_write_v2(request: proto_base.ExecutePlanRequest):
     snowpark_table_name = _spark_to_snowflake(write_op.table_name)
     result = map_relation(write_op.input)
     input_df: snowpark.DataFrame = handle_column_names(result, "table")
+
+    # Create updated container with transformed dataframe, then filter METADATA$FILENAME columns
+    # Update the container to use the transformed dataframe from handle_column_names
+    updated_result = DataFrameContainer(
+        dataframe=input_df,
+        column_map=result.column_map,
+        table_name=result.table_name,
+        alias=result.alias,
+        partition_hint=result.partition_hint,
+    )
+    updated_result = filter_metadata_columns(updated_result)
+    input_df = updated_result.dataframe
+
     session: snowpark.Session = get_or_create_snowpark_session()
 
     if write_op.table_name is None or write_op.table_name == "":
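The new write path validates the requested compression codec against the output format before the COPY parameters are built. The sketch below shows the shape of that check with a hypothetical allow-list; the real is_supported_compression / supported_compressions_for_format helpers live in relation/io_utils.py and may map formats differently.

```python
# Hypothetical allow-list used only for illustration; the real mapping lives in
# snowflake/snowpark_connect/relation/io_utils.py.
_SUPPORTED_COMPRESSIONS = {
    "csv": {"NONE", "GZIP", "BZ2", "ZSTD", "DEFLATE"},
    "json": {"NONE", "GZIP", "BZ2", "ZSTD", "DEFLATE"},
    "parquet": {"NONE", "SNAPPY"},
    "text": {"NONE", "GZIP"},
}


def supported_compressions_for_format(source: str) -> set[str]:
    return _SUPPORTED_COMPRESSIONS.get(source.lower(), set())


def is_supported_compression(source: str, compression: str) -> bool:
    return compression.upper() in supported_compressions_for_format(source)


def validate_compression(source: str, compression: str) -> None:
    if not is_supported_compression(source, compression):
        supported = supported_compressions_for_format(source)
        raise ValueError(
            f"Compression {compression} is not supported for {source} format. "
            + (
                f"Supported compressions: {sorted(supported)}"
                if supported
                else "No compression supported for this format."
            )
        )


validate_compression("parquet", "snappy")   # passes
# validate_compression("parquet", "gzip")   # would raise ValueError
```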

snowflake/snowpark_connect/server.py
@@ -232,12 +232,20 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         match request.WhichOneof("analyze"):
             case "schema":
                 result = map_relation(request.schema.plan.root)
-
-
+
+                from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                    filter_metadata_columns,
+                )
+
+                filtered_result = filter_metadata_columns(result)
+                filtered_df = filtered_result.dataframe
+
                 schema = proto_base.AnalyzePlanResponse.Schema(
                     schema=types_proto.DataType(
                         **snowpark_to_proto_type(
-
+                            filtered_df.schema,
+                            filtered_result.column_map,
+                            filtered_df,
                         )
                     )
                 )
@@ -1161,23 +1169,28 @@ def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession
 
 
 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-
-
-
-
-
-
-
-
-
-
-
-
+    if os.environ.get("JAVA_HOME") is None:
+        try:
+            # For Notebooks on SPCS
+            from jdk4py import JAVA_HOME
+
+            os.environ["JAVA_HOME"] = str(JAVA_HOME)
+        except ModuleNotFoundError:
+            # For notebooks on Warehouse
+            conda_prefix = os.environ.get("CONDA_PREFIX")
+            if conda_prefix is not None:
+                os.environ["JAVA_HOME"] = conda_prefix
+                os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
+                    conda_prefix, "lib", "server"
+                )
+    logger.info("JAVA_HOME=%s", os.environ.get("JAVA_HOME", "Not defined"))
 
     os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
     os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
 
-
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
+
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)
 
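init_spark_session now bootstraps JAVA_HOME when it is missing, preferring the jdk4py bundled JDK and falling back to the active conda environment. Here is a standalone sketch of that resolution order; the helper name and the exact library layout are illustrative assumptions, not the package's API.

```python
import os


def resolve_java_home() -> str | None:
    """Pick a JAVA_HOME if the environment does not already define one (sketch)."""
    if os.environ.get("JAVA_HOME") is not None:
        return os.environ["JAVA_HOME"]
    try:
        # jdk4py ships a private JDK and exposes its location as JAVA_HOME.
        from jdk4py import JAVA_HOME

        os.environ["JAVA_HOME"] = str(JAVA_HOME)
    except ModuleNotFoundError:
        # Fall back to a conda-provided JDK if one is active.
        conda_prefix = os.environ.get("CONDA_PREFIX")
        if conda_prefix is not None:
            os.environ["JAVA_HOME"] = conda_prefix
            # Assumed layout for the JVM shared libraries under the env prefix.
            os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
                conda_prefix, "lib", "server"
            )
    return os.environ.get("JAVA_HOME")


print("JAVA_HOME=%s" % (resolve_java_home() or "Not defined"))
```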

snowflake/snowpark_connect/type_mapping.py
@@ -30,6 +30,10 @@ from snowflake.snowpark_connect.date_time_format_mapping import (
     convert_spark_format_to_snowflake,
 )
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
+from snowflake.snowpark_connect.expression.map_sql_expression import (
+    _INTERVAL_DAYTIME_PATTERN_RE,
+    _INTERVAL_YEARMONTH_PATTERN_RE,
+)
 from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -274,6 +278,18 @@ def snowpark_to_proto_type(
         case snowpark.types.VariantType:
             # For now we are returning a string type for variant types.
             return {"string": types_proto.DataType.String()}
+        case snowpark.types.YearMonthIntervalType:
+            return {
+                "year_month_interval": types_proto.DataType.YearMonthInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
+        case snowpark.types.DayTimeIntervalType:
+            return {
+                "day_time_interval": types_proto.DataType.DayTimeInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {data_type}"
@@ -328,6 +344,24 @@ def cast_to_match_snowpark_type(
             return str(content)
         case snowpark.types.TimestampType:
             return str(content)
+        case snowpark.types.YearMonthIntervalType:
+            if isinstance(content, (int, float)):
+                total_months = int(content)
+                years = total_months // 12
+                months = total_months % 12
+                return f"INTERVAL '{years}-{months}' YEAR TO MONTH"
+            elif isinstance(content, str) and content.startswith(("+", "-")):
+                # Handle Snowflake's native interval format (e.g., "+11-08" or "-2-3")
+                # Convert to Spark's format: "INTERVAL 'Y-M' YEAR TO MONTH"
+                sign = content[0]
+                interval_part = content[1:]  # Remove sign
+                if sign == "-":
+                    return f"INTERVAL '-{interval_part}' YEAR TO MONTH"
+                else:
+                    return f"INTERVAL '{interval_part}' YEAR TO MONTH"
+            return str(content)
+        case snowpark.types.DayTimeIntervalType:
+            return str(content)
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"
@@ -411,6 +445,18 @@ def proto_to_snowpark_type(
             # For UDT types, return the underlying SQL type
             logger.debug("Returning underlying sql type for udt")
             return proto_to_snowpark_type(data_type.udt.sql_type)
+        case "year_month_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.YearMonthIntervalType(
+                start_field=data_type.year_month_interval.start_field,
+                end_field=data_type.year_month_interval.end_field,
+            )
+        case "day_time_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.DayTimeIntervalType(
+                start_field=data_type.day_time_interval.start_field,
+                end_field=data_type.day_time_interval.end_field,
+            )
         case _:
             return map_simple_types(data_type.WhichOneof("kind"))
 
@@ -523,6 +569,12 @@ def map_snowpark_types_to_pyarrow_types(
             return pa.timestamp(unit, tz=tz)
         case snowpark.types.VariantType:
             return pa.string()
+        case snowpark.types.YearMonthIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
+        case snowpark.types.DayTimeIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {snowpark_type}"
@@ -676,6 +728,14 @@ def map_pyspark_types_to_snowpark_types(
         return snowpark.types.TimestampType()
     if isinstance(type_to_map, pyspark.sql.types.TimestampNTZType):
         return snowpark.types.TimestampType(timezone=TimestampTimeZone.NTZ)
+    if isinstance(type_to_map, pyspark.sql.types.YearMonthIntervalType):
+        return snowpark.types.YearMonthIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
+    if isinstance(type_to_map, pyspark.sql.types.DayTimeIntervalType):
+        return snowpark.types.DayTimeIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
     raise SnowparkConnectNotImplementedError(
         f"Unsupported spark data type: {type_to_map}"
     )
@@ -743,6 +803,14 @@ def map_snowpark_to_pyspark_types(
         if type_to_map.tz == snowpark.types.TimestampTimeZone.NTZ:
             return pyspark.sql.types.TimestampNTZType()
         return pyspark.sql.types.TimestampType()
+    if isinstance(type_to_map, snowpark.types.YearMonthIntervalType):
+        return pyspark.sql.types.YearMonthIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
+    if isinstance(type_to_map, snowpark.types.DayTimeIntervalType):
+        return pyspark.sql.types.DayTimeIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
     raise SnowparkConnectNotImplementedError(f"Unsupported data type: {type_to_map}")
 
 
@@ -785,10 +853,14 @@ def map_simple_types(simple_type: str) -> snowpark.types.DataType:
             return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_ltz":
            return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.LTZ)
+        case "year_month_interval":
+            return snowpark.types.YearMonthIntervalType()
        case "day_time_interval":
-
-
-            return snowpark.types.
+            return snowpark.types.DayTimeIntervalType()
+        case type_name if _INTERVAL_YEARMONTH_PATTERN_RE.match(type_name):
+            return snowpark.types.YearMonthIntervalType()
+        case type_name if _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
+            return snowpark.types.DayTimeIntervalType()
         case _:
             if simple_type.startswith("decimal"):
                 precision = int(simple_type.split("(")[1].split(",")[0])
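Most of the type_mapping.py changes thread Spark's YearMonthIntervalType and DayTimeIntervalType through the proto, Snowpark, and PyArrow conversions; cast_to_match_snowpark_type additionally reformats year-month values into Spark's interval literal. A small standalone sketch of that year-month formatting, mirroring the integer and signed-string branches above, is shown here.

```python
def format_year_month_interval(content) -> str:
    """Render a year-month interval as a Spark-style literal (sketch)."""
    if isinstance(content, (int, float)):
        # Interpret a numeric value as a total month count.
        total_months = int(content)
        years, months = divmod(total_months, 12)
        return f"INTERVAL '{years}-{months}' YEAR TO MONTH"
    if isinstance(content, str) and content.startswith(("+", "-")):
        # Snowflake-style signed text such as "+11-08": keep the minus, drop the plus.
        sign, body = content[0], content[1:]
        prefix = "-" if sign == "-" else ""
        return f"INTERVAL '{prefix}{body}' YEAR TO MONTH"
    return str(content)


print(format_year_month_interval(140))       # INTERVAL '11-8' YEAR TO MONTH
print(format_year_month_interval("+11-08"))  # INTERVAL '11-08' YEAR TO MONTH
```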

snowflake/snowpark_connect/utils/context.py
@@ -30,9 +30,6 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )
 
-# Context for parsing map_partitions
-_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
-
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -230,16 +227,6 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)
 
 
-@contextmanager
-def push_map_partitions():
-    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
-    yield
-
-
-def map_partitions_depth() -> int:
-    return _map_partitions_stack.get()
-
-
 @contextmanager
 def push_sql_scope():
     """
@@ -410,7 +397,6 @@ def clear_context_data() -> None:
     _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
-    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})

snowflake/snowpark_connect/utils/describe_query_cache.py
@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry
 
-DESCRIBE_CACHE_TTL_SECONDS = 15
 USE_DESCRIBE_QUERY_CACHE = True
 
 DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
@@ -51,6 +50,8 @@ class DescribeQueryCache:
         return sql_query
 
     def get(self, sql_query: str) -> list[ResultMetadataV2] | None:
+        from snowflake.snowpark_connect.config import get_describe_cache_ttl_seconds
+
         telemetry.report_describe_query_cache_lookup()
 
         cache_key = self._get_cache_key(sql_query)
@@ -59,7 +60,9 @@
 
         if key in self._cache:
             result, timestamp = self._cache[key]
-
+
+            expired_by = current_time - (timestamp + get_describe_cache_ttl_seconds())
+            if expired_by < 0:
                 logger.debug(
                     f"Returning query result from cache for query: {sql_query[:20]}"
                 )
@@ -92,7 +95,7 @@
                 telemetry.report_describe_query_cache_hit()
                 return result
             else:
-                telemetry.report_describe_query_cache_expired()
+                telemetry.report_describe_query_cache_expired(expired_by)
                 del self._cache[key]
                 return None
 
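The describe-query cache now reads its TTL from config at lookup time and reports how far past the deadline an entry was when it expired. A minimal sketch of that expiry arithmetic, with a fixed TTL standing in for get_describe_cache_ttl_seconds, is below.

```python
import time


def get_ttl_seconds() -> float:
    # Stand-in for snowflake.snowpark_connect.config.get_describe_cache_ttl_seconds().
    return 15.0


_cache: dict[str, tuple[str, float]] = {}


def cache_get(key: str):
    entry = _cache.get(key)
    if entry is None:
        return None
    result, timestamp = entry
    # Negative means the entry is still fresh; positive is seconds past expiry.
    expired_by = time.time() - (timestamp + get_ttl_seconds())
    if expired_by < 0:
        return result
    # Expired: evict and report how late the entry was (e.g. to telemetry).
    del _cache[key]
    print(f"entry expired {expired_by:.1f}s ago")
    return None


_cache["DESCRIBE SELECT 1"] = ("metadata", time.time())
print(cache_get("DESCRIBE SELECT 1"))  # fresh -> "metadata"
```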

snowflake/snowpark_connect/utils/io_utils.py
@@ -3,10 +3,46 @@
 #
 import contextlib
 import functools
+import re
 
 from snowflake.snowpark import Session
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    create_file_format_statement,
+)
 from snowflake.snowpark_connect.utils.identifiers import FQN
 
+_MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")
+
+
+def cached_file_format(
+    session: Session, file_format: str, format_type_options: dict[str, str]
+) -> str:
+    """
+    Cache and return a file format name based on the given options.
+    """
+
+    function_name = _MINUS_AT_THE_BEGINNING_REGEX.sub(
+        "1", str(hash(frozenset(format_type_options.items())))
+    )
+    file_format_name = f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{function_name}"
+    if file_format_name in session._file_formats:
+        return file_format_name
+
+    session.sql(
+        create_file_format_statement(
+            file_format_name,
+            file_format,
+            format_type_options,
+            temp=True,
+            if_not_exist=True,
+            use_scoped_temp_objects=False,
+            is_generated=True,
+        )
+    ).collect()
+
+    session._file_formats.add(file_format_name)
+    return file_format_name
+
 
 @functools.cache
 def file_format(
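cached_file_format derives a session-scoped name for a temporary file format from a hash of the option set, flipping a leading minus sign so the suffix stays identifier-safe, and only issues CREATE FILE FORMAT the first time that name is seen. The naming part is easy to show in isolation:

```python
import re

_MINUS_AT_THE_BEGINNING = re.compile(r"^-")


def file_format_cache_name(file_format: str, options: dict[str, str]) -> str:
    # hash() of a frozenset is order-independent, so equal option dicts
    # collapse to the same name; a leading "-" is rewritten to keep the
    # suffix identifier-safe.
    suffix = _MINUS_AT_THE_BEGINNING.sub("1", str(hash(frozenset(options.items()))))
    return f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{suffix}"


opts = {"FIELD_DELIMITER": ",", "SKIP_HEADER": "1"}
name = file_format_cache_name("CSV", opts)
print(name)
# Same options in any order map to the same name within this process.
assert name == file_format_cache_name("CSV", dict(reversed(list(opts.items()))))
```

Note that Python randomizes string hashing per process, so the derived name is only stable within one session, which is all a cache of temporary, session-scoped file formats needs.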

snowflake/snowpark_connect/utils/session.py
@@ -71,6 +71,9 @@ def configure_snowpark_session(session: snowpark.Session):
     init_builtin_udf_cache(session)
     init_external_udxf_cache(session)
 
+    # file format cache
+    session._file_formats = set()
+
     # Set experimental parameters (warnings globally suppressed)
     session.ast_enabled = False
     session.eliminate_numeric_sql_value_cast_enabled = False
@@ -117,6 +120,7 @@
         "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": "false",  # this is required for creating udfs from sproc
         "ENABLE_STRUCTURED_TYPES_IN_SNOWPARK_CONNECT_RESPONSE": "true",
         "QUERY_TAG": f"'{query_tag}'",
+        "FEATURE_INTERVAL_TYPES": "enabled",
     }
 
     session.sql(