snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (47)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +33 -5
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  12. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  13. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  14. snowflake/snowpark_connect/relation/map_join.py +8 -0
  15. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  16. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  17. snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
  18. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  19. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  20. snowflake/snowpark_connect/relation/map_stats.py +1 -1
  21. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  22. snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  25. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  26. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  27. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  28. snowflake/snowpark_connect/relation/write/map_write.py +67 -4
  29. snowflake/snowpark_connect/server.py +29 -16
  30. snowflake/snowpark_connect/type_mapping.py +75 -3
  31. snowflake/snowpark_connect/utils/context.py +0 -14
  32. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  33. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  34. snowflake/snowpark_connect/utils/session.py +4 -0
  35. snowflake/snowpark_connect/utils/telemetry.py +30 -5
  36. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  37. snowflake/snowpark_connect/version.py +1 -1
  38. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
  39. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
  40. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
  43. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
  44. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
  45. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
  46. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
  47. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@
 import os
 
 from fsspec.core import url_to_fs
+from pyspark.errors.exceptions.base import AnalysisException
 from s3fs.core import S3FileSystem
 
 from snowflake import snowpark
@@ -33,37 +34,42 @@ def get_paths_from_stage(
 
     # TODO : What if GCP?
     # TODO: What if already stage path?
-    if get_cloud_from_url(paths[0]) == "azure":
-        rewrite_paths = []
-        for p in paths:
-            _, bucket_name, path = parse_azure_url(p)
-            rewrite_paths.append(f"{stage_name}/{path}")
-        paths = rewrite_paths
-    else:
-        filesystem, parsed_path = url_to_fs(paths[0])
-        if isinstance(filesystem, S3FileSystem):  # aws
-            # Remove bucket name from the path since the stage name will replace
-            # the bucket name in the path.
-            paths = [
-                f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
-                for p in paths
-            ]
-        else:  # local
-            # For local files, we need to preserve directory structure for partitioned data
-            # Instead of just using basename, we'll use the last few path components
-            new_paths = []
+    match get_cloud_from_url(paths[0]):
+        case "azure":
+            rewrite_paths = []
             for p in paths:
-                # Split the path and take the last 2-3 components to preserve structure
-                # but avoid very long paths
-                path_parts = p.split(os.sep)
-                if len(path_parts) >= 2:
-                    # Take last 2 components (e.g., "base_case/x=abc")
-                    relative_path = "/".join(path_parts[-2:])
-                else:
-                    # Single component, use basename
-                    relative_path = os.path.basename(p)
-                new_paths.append(f"{stage_name}/{relative_path}")
-            paths = new_paths
+                _, bucket_name, path = parse_azure_url(p)
+                rewrite_paths.append(f"{stage_name}/{path}")
+            paths = rewrite_paths
+        case "gcp":
+            raise AnalysisException(
+                "You must configure an integration for Google Cloud Storage to perform I/O operations rather than accessing the URL directly. Reference: https://docs.snowflake.com/en/user-guide/data-load-gcs-config"
+            )
+        case _:
+            filesystem, parsed_path = url_to_fs(paths[0])
+            if isinstance(filesystem, S3FileSystem):  # aws
+                # Remove bucket name from the path since the stage name will replace
+                # the bucket name in the path.
+                paths = [
+                    f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
+                    for p in paths
+                ]
+            else:  # local
+                # For local files, we need to preserve directory structure for partitioned data
+                # Instead of just using basename, we'll use the last few path components
+                new_paths = []
+                for p in paths:
+                    # Split the path and take the last 2-3 components to preserve structure
+                    # but avoid very long paths
+                    path_parts = p.split(os.sep)
+                    if len(path_parts) >= 2:
+                        # Take last 2 components (e.g., "base_case/x=abc")
+                        relative_path = "/".join(path_parts[-2:])
+                    else:
+                        # Single component, use basename
+                        relative_path = os.path.basename(p)
+                    new_paths.append(f"{stage_name}/{relative_path}")
+                paths = new_paths
 
     return paths
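Note on the hunk above: 0.30.0 replaces the azure/else branching with a match on the detected cloud provider and now fails fast for GCS URLs. A minimal standalone sketch of that dispatch follows; the path splitting and error type are simplified assumptions, not the package's implementation:

    import os

    def rewrite_paths_for_stage(paths: list[str], stage_name: str, cloud: str) -> list[str]:
        # Hedged sketch: mirrors only the shape of the new dispatch.
        match cloud:
            case "azure":
                # append the container-relative portion under the stage
                return [f"{stage_name}/{p.split('/', 3)[-1]}" for p in paths]
            case "gcp":
                # the real code raises AnalysisException with a docs link
                raise ValueError("GCS requires a storage integration; direct URL access is unsupported.")
            case "aws":
                # drop the bucket name; the stage already points at the bucket
                return [f"{stage_name}/{p.split('/', 3)[-1]}" for p in paths]
            case _:
                # local files: keep the last two path components for partitioned data
                return [f"{stage_name}/{'/'.join(p.split(os.sep)[-2:])}" for p in paths]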
 
@@ -102,15 +108,21 @@ class StageLocator:
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='azure://{account}.blob.core.windows.net/{bucket_name}'"
 
             credential_session_key = (
-                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net"
+                f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
+                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
             )
             credential = sessions_config.get(spark_session_id, None)
-            if (
-                credential is not None
-                and credential.get(credential_session_key) is not None
-                and credential.get(credential_session_key).strip() != ""
-            ):
-                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{credential.get(credential_session_key)}')"
+            sas_token = None
+            for session_key in credential_session_key:
+                if (
+                    credential is not None
+                    and credential.get(session_key) is not None
+                    and credential.get(session_key).strip() != ""
+                ):
+                    sas_token = credential.get(session_key)
+                    break
+            if sas_token is not None:
+                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{sas_token}')"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_azure[bucket_name] = stage_name
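The SAS-token lookup above now probes two Hadoop-style config keys (a fixed-token dfs key and the per-container blob key) and takes the first non-empty hit. A self-contained restatement of that first-match lookup, with the key names copied from the hunk and the surrounding session config replaced by a plain dict:

    def first_nonempty_sas_token(credential: dict[str, str] | None, account: str, container: str) -> str | None:
        candidate_keys = (
            f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
            f"fs.azure.sas.{container}.{account}.blob.core.windows.net",
        )
        if credential is None:
            return None
        for key in candidate_keys:
            token = credential.get(key)
            if token is not None and token.strip() != "":
                return token
        return None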
@@ -128,24 +140,44 @@ class StageLocator:
             # but the rest of the time it's used, it does. We just drop it here.
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='s3://{parsed_path.split('/')[0]}'"
             credential = sessions_config.get(spark_session_id, None)
-            if (
-                credential is not None
-                and credential.get("spark.hadoop.fs.s3a.access.key") is not None
-                and credential.get("spark.hadoop.fs.s3a.secret.key") is not None
-                and credential.get("spark.hadoop.fs.s3a.access.key").strip()
-                != ""
-                and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
-                != ""
-            ):
-                aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
-                aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
-                if (
-                    credential.get("spark.hadoop.fs.s3a.session.token")
+            if credential is not None:
+                if (  # USE AWS KEYS to connect
+                    credential.get("spark.hadoop.fs.s3a.access.key") is not None
+                    and credential.get("spark.hadoop.fs.s3a.secret.key")
+                    is not None
+                    and credential.get("spark.hadoop.fs.s3a.access.key").strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
+                    != ""
+                ):
+                    aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
+                    aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
+                    if (
+                        credential.get("spark.hadoop.fs.s3a.session.token")
+                        is not None
+                    ):
+                        aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
+                    sql_query += f" CREDENTIALS = ({aws_keys})"
+                    sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                elif (  # USE AWS ROLE and KMS KEY to connect
+                    credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    )
+                    is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    ).strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.assumed.role.arn")
                     is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.assumed.role.arn"
+                    ).strip()
+                    != ""
                 ):
-                    aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
-                sql_query += f" CREDENTIALS = ({aws_keys})"
-                sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                    aws_role = f" AWS_ROLE = '{credential.get('spark.hadoop.fs.s3a.assumed.role.arn')}'"
+                    sql_query += f" CREDENTIALS = ({aws_role})"
+                    sql_query += f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.server-side-encryption.key')}' )"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_aws[bucket_name] = stage_name
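The rewritten S3 branch now supports two credential shapes: static access keys (with an optional session token, SSE-S3 encryption) or an assumed role ARN paired with a KMS key (SSE-KMS). A sketch of that clause selection, with the Hadoop option names taken from the hunk and the SQL assembly simplified:

    def s3_stage_clauses(cred: dict[str, str]) -> str:
        access = cred.get("spark.hadoop.fs.s3a.access.key", "").strip()
        secret = cred.get("spark.hadoop.fs.s3a.secret.key", "").strip()
        role_arn = cred.get("spark.hadoop.fs.s3a.assumed.role.arn", "").strip()
        kms_key = cred.get("spark.hadoop.fs.s3a.server-side-encryption.key", "").strip()
        if access and secret:
            keys = f"AWS_KEY_ID = '{access}' AWS_SECRET_KEY = '{secret}'"
            token = cred.get("spark.hadoop.fs.s3a.session.token")
            if token is not None:
                keys += f" AWS_TOKEN = '{token}'"
            return f" CREDENTIALS = ({keys}) ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
        if role_arn and kms_key:
            return (
                f" CREDENTIALS = (AWS_ROLE = '{role_arn}')"
                f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{kms_key}' )"
            )
        return ""  # no usable credentials configured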
@@ -36,8 +36,13 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.io_utils import (
     convert_file_prefix_path,
     is_cloud_path,
+    is_supported_compression,
+    supported_compressions_for_format,
 )
 from snowflake.snowpark_connect.relation.map_relation import map_relation
+from snowflake.snowpark_connect.relation.read.metadata_utils import (
+    filter_metadata_columns,
+)
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
 from snowflake.snowpark_connect.relation.stage_locator import get_paths_from_stage
 from snowflake.snowpark_connect.relation.utils import (
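The two helpers added to this import block are used further down in map_write to reject unsupported codecs before COPY INTO runs. Their real allow-lists live in relation/io_utils.py; the table in this sketch is an assumed example, only the helper shapes come from the diff:

    _SUPPORTED_COMPRESSIONS = {  # assumed values, not the shipped table
        "csv": {"NONE", "GZIP"},
        "json": {"NONE", "GZIP"},
        "parquet": {"NONE", "SNAPPY"},
    }

    def supported_compressions_for_format(fmt: str) -> set[str]:
        return _SUPPORTED_COMPRESSIONS.get(fmt.lower(), set())

    def is_supported_compression(fmt: str, compression: str) -> bool:
        return compression.upper() in supported_compressions_for_format(fmt)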
@@ -127,6 +132,19 @@ def map_write(request: proto_base.ExecutePlanRequest):
 
     result = map_relation(write_op.input)
     input_df: snowpark.DataFrame = handle_column_names(result, write_op.source)
+
+    # Create updated container with transformed dataframe, then filter METADATA$FILENAME columns
+    # Update the container to use the transformed dataframe from handle_column_names
+    updated_result = DataFrameContainer(
+        dataframe=input_df,
+        column_map=result.column_map,
+        table_name=result.table_name,
+        alias=result.alias,
+        partition_hint=result.partition_hint,
+    )
+    updated_result = filter_metadata_columns(updated_result)
+    input_df = updated_result.dataframe
+
     session: snowpark.Session = get_or_create_snowpark_session()
 
     # Snowflake saveAsTable doesn't support format
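Both write paths now rebuild the DataFrameContainer and run it through filter_metadata_columns so reader-injected METADATA$FILENAME columns never reach the output. A rough sketch of what such a filter does; the real helper lives in relation/read/metadata_utils.py and works on containers, and the column-name test below is an illustrative assumption:

    from snowflake import snowpark

    def drop_metadata_columns(df: snowpark.DataFrame) -> snowpark.DataFrame:
        # Keep only user-visible columns; METADATA$... columns carry file
        # provenance for readers and should not be written out or surfaced in schemas.
        kept = [c for c in df.columns if not c.strip('"').upper().startswith("METADATA$")]
        return df.select(*kept) if len(kept) < len(df.columns) else df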
@@ -179,7 +197,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 f"Skipping REMOVE for root path {write_path} - too broad scope"
             )
         else:
-            remove_command = f"REMOVE {write_path}/"
+            remove_command = f"REMOVE '{write_path}/'"
             session.sql(remove_command).collect()
             logger.info(f"Successfully cleared directory: {write_path}")
     except Exception as e:
@@ -208,6 +226,20 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 compression = write_op.options.get(
                     "compression", default_compression
                 ).upper()
+
+                if not is_supported_compression(write_op.source, compression):
+                    supported_compressions = supported_compressions_for_format(
+                        write_op.source
+                    )
+                    raise AnalysisException(
+                        f"Compression {compression} is not supported for {write_op.source} format. "
+                        + (
+                            f"Supported compressions: {sorted(supported_compressions)}"
+                            if supported_compressions
+                            else "No compression supported for this format."
+                        )
+                    )
+
                 parameters = {
                     "location": temp_file_prefix_on_stage,
                     "file_format_type": write_op.source
@@ -417,9 +449,27 @@ def map_write(request: proto_base.ExecutePlanRequest):
             )
         case _:
             snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
+            save_method = write_op.table.save_method
 
             if (
-                write_op.table.save_method
+                write_op.source == "snowflake"
+                and write_op.table.save_method
+                == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_UNSPECIFIED
+            ):
+                save_method = (
+                    commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
+                )
+                if len(write_op.table.table_name) == 0:
+                    dbtable_name = write_op.options.get("dbtable", "")
+                    if len(dbtable_name) == 0:
+                        raise SnowparkConnectNotImplementedError(
+                            "Save command is not supported without a table name"
+                        )
+                    else:
+                        snowpark_table_name = _spark_to_snowflake(dbtable_name)
+
+            if (
+                save_method
                 == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
             ):
                 match write_mode:
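In practice this means a bare save() on the snowflake source now behaves like saveAsTable, with the target taken from the dbtable option when the command carries no table name. A hypothetical client-side call that would exercise the new path (the table name is a placeholder):

    from pyspark.sql import DataFrame

    def save_via_dbtable_option(df: DataFrame) -> None:
        # No explicit table name on the command; target comes from "dbtable".
        df.write.format("snowflake").option("dbtable", "MY_DB.PUBLIC.TARGET").mode(
            "append"
        ).save()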
@@ -481,7 +531,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                         column_order=_column_order_for_write,
                     )
             elif (
-                write_op.table.save_method
+                save_method
                 == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
             ):
                 _validate_schema_and_get_writer(
@@ -493,7 +543,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 )
             else:
                 raise SnowparkConnectNotImplementedError(
-                    f"Save command not supported: {write_op.table.save_method}"
+                    f"Save command not supported: {save_method}"
                 )
 
 
@@ -503,6 +553,19 @@ def map_write_v2(request: proto_base.ExecutePlanRequest):
     snowpark_table_name = _spark_to_snowflake(write_op.table_name)
     result = map_relation(write_op.input)
     input_df: snowpark.DataFrame = handle_column_names(result, "table")
+
+    # Create updated container with transformed dataframe, then filter METADATA$FILENAME columns
+    # Update the container to use the transformed dataframe from handle_column_names
+    updated_result = DataFrameContainer(
+        dataframe=input_df,
+        column_map=result.column_map,
+        table_name=result.table_name,
+        alias=result.alias,
+        partition_hint=result.partition_hint,
+    )
+    updated_result = filter_metadata_columns(updated_result)
+    input_df = updated_result.dataframe
+
     session: snowpark.Session = get_or_create_snowpark_session()
 
     if write_op.table_name is None or write_op.table_name == "":
@@ -232,12 +232,20 @@ class SnowflakeConnectServicer(proto_base_grpc.SparkConnectServiceServicer):
         match request.WhichOneof("analyze"):
             case "schema":
                 result = map_relation(request.schema.plan.root)
-                snowpark_df = result.dataframe
-                snowpark_schema: snowpark.types.StructType = snowpark_df.schema
+
+                from snowflake.snowpark_connect.relation.read.metadata_utils import (
+                    filter_metadata_columns,
+                )
+
+                filtered_result = filter_metadata_columns(result)
+                filtered_df = filtered_result.dataframe
+
                 schema = proto_base.AnalyzePlanResponse.Schema(
                     schema=types_proto.DataType(
                         **snowpark_to_proto_type(
-                            snowpark_schema, result.column_map, snowpark_df
+                            filtered_df.schema,
+                            filtered_result.column_map,
+                            filtered_df,
                         )
                     )
                 )
@@ -1161,23 +1169,28 @@ def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession:
 
 
 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-    try:
-        # For Notebooks on SPCS
-        from jdk4py import JAVA_HOME
-
-        os.environ["JAVA_HOME"] = str(JAVA_HOME)
-    except ModuleNotFoundError:
-        # For notebooks on Warehouse
-        os.environ["JAVA_HOME"] = os.environ["CONDA_PREFIX"]
-        os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
-            os.environ["CONDA_PREFIX"], "lib", "server"
-        )
-    logger.info("JAVA_HOME=%s", os.environ["JAVA_HOME"])
+    if os.environ.get("JAVA_HOME") is None:
+        try:
+            # For Notebooks on SPCS
+            from jdk4py import JAVA_HOME
+
+            os.environ["JAVA_HOME"] = str(JAVA_HOME)
+        except ModuleNotFoundError:
+            # For notebooks on Warehouse
+            conda_prefix = os.environ.get("CONDA_PREFIX")
+            if conda_prefix is not None:
+                os.environ["JAVA_HOME"] = conda_prefix
+                os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
+                    conda_prefix, "lib", "server"
+                )
+    logger.info("JAVA_HOME=%s", os.environ.get("JAVA_HOME", "Not defined"))
 
     os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
     os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
 
-    snowpark_session = snowpark.context.get_active_session()
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
+
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)
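init_spark_session now leaves a preset JAVA_HOME alone and no longer assumes CONDA_PREFIX exists. The resolution order, restated as a standalone sketch (jdk4py is an optional dependency; the helper name is illustrative, not part of the package):

    import os

    def resolve_java_home() -> str | None:
        if os.environ.get("JAVA_HOME"):
            return os.environ["JAVA_HOME"]  # 1. respect an existing value
        try:
            from jdk4py import JAVA_HOME  # 2. SPCS notebooks ship jdk4py
            return str(JAVA_HOME)
        except ModuleNotFoundError:
            return os.environ.get("CONDA_PREFIX")  # 3. warehouse notebooks; may be None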
 
@@ -30,6 +30,10 @@ from snowflake.snowpark_connect.date_time_format_mapping import (
     convert_spark_format_to_snowflake,
 )
 from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
+from snowflake.snowpark_connect.expression.map_sql_expression import (
+    _INTERVAL_DAYTIME_PATTERN_RE,
+    _INTERVAL_YEARMONTH_PATTERN_RE,
+)
 from snowflake.snowpark_connect.utils.context import get_is_evaluating_sql
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -274,6 +278,18 @@ def snowpark_to_proto_type(
         case snowpark.types.VariantType:
             # For now we are returning a string type for variant types.
            return {"string": types_proto.DataType.String()}
+        case snowpark.types.YearMonthIntervalType:
+            return {
+                "year_month_interval": types_proto.DataType.YearMonthInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
+        case snowpark.types.DayTimeIntervalType:
+            return {
+                "day_time_interval": types_proto.DataType.DayTimeInterval(
+                    start_field=data_type.start_field, end_field=data_type.end_field
+                )
+            }
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {data_type}"
@@ -328,6 +344,24 @@ def cast_to_match_snowpark_type(
             return str(content)
         case snowpark.types.TimestampType:
             return str(content)
+        case snowpark.types.YearMonthIntervalType:
+            if isinstance(content, (int, float)):
+                total_months = int(content)
+                years = total_months // 12
+                months = total_months % 12
+                return f"INTERVAL '{years}-{months}' YEAR TO MONTH"
+            elif isinstance(content, str) and content.startswith(("+", "-")):
+                # Handle Snowflake's native interval format (e.g., "+11-08" or "-2-3")
+                # Convert to Spark's format: "INTERVAL 'Y-M' YEAR TO MONTH"
+                sign = content[0]
+                interval_part = content[1:]  # Remove sign
+                if sign == "-":
+                    return f"INTERVAL '-{interval_part}' YEAR TO MONTH"
+                else:
+                    return f"INTERVAL '{interval_part}' YEAR TO MONTH"
+            return str(content)
+        case snowpark.types.DayTimeIntervalType:
+            return str(content)
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"
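A worked example of the year-month branch above: a raw month count is split into years and months and rendered in Spark's interval-literal syntax (the helper name is illustrative, not package API):

    def format_year_month_interval(total_months: int) -> str:
        years, months = divmod(int(total_months), 12)
        return f"INTERVAL '{years}-{months}' YEAR TO MONTH"

    # 140 months -> 11 years, 8 months
    assert format_year_month_interval(140) == "INTERVAL '11-8' YEAR TO MONTH"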
@@ -411,6 +445,18 @@ def proto_to_snowpark_type(
             # For UDT types, return the underlying SQL type
             logger.debug("Returning underlying sql type for udt")
             return proto_to_snowpark_type(data_type.udt.sql_type)
+        case "year_month_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.YearMonthIntervalType(
+                start_field=data_type.year_month_interval.start_field,
+                end_field=data_type.year_month_interval.end_field,
+            )
+        case "day_time_interval":
+            # Preserve start_field and end_field from protobuf
+            return snowpark.types.DayTimeIntervalType(
+                start_field=data_type.day_time_interval.start_field,
+                end_field=data_type.day_time_interval.end_field,
+            )
         case _:
             return map_simple_types(data_type.WhichOneof("kind"))
 
@@ -523,6 +569,12 @@ def map_snowpark_types_to_pyarrow_types(
             return pa.timestamp(unit, tz=tz)
         case snowpark.types.VariantType:
             return pa.string()
+        case snowpark.types.YearMonthIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
+        case snowpark.types.DayTimeIntervalType:
+            # Return string type so formatted intervals are preserved in display
+            return pa.string()
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type: {snowpark_type}"
@@ -676,6 +728,14 @@ def map_pyspark_types_to_snowpark_types(
         return snowpark.types.TimestampType()
     if isinstance(type_to_map, pyspark.sql.types.TimestampNTZType):
         return snowpark.types.TimestampType(timezone=TimestampTimeZone.NTZ)
+    if isinstance(type_to_map, pyspark.sql.types.YearMonthIntervalType):
+        return snowpark.types.YearMonthIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
+    if isinstance(type_to_map, pyspark.sql.types.DayTimeIntervalType):
+        return snowpark.types.DayTimeIntervalType(
+            type_to_map.startField, type_to_map.endField
+        )
     raise SnowparkConnectNotImplementedError(
         f"Unsupported spark data type: {type_to_map}"
     )
@@ -743,6 +803,14 @@ def map_snowpark_to_pyspark_types(
         if type_to_map.tz == snowpark.types.TimestampTimeZone.NTZ:
             return pyspark.sql.types.TimestampNTZType()
         return pyspark.sql.types.TimestampType()
+    if isinstance(type_to_map, snowpark.types.YearMonthIntervalType):
+        return pyspark.sql.types.YearMonthIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
+    if isinstance(type_to_map, snowpark.types.DayTimeIntervalType):
+        return pyspark.sql.types.DayTimeIntervalType(
+            type_to_map.start_field, type_to_map.end_field
+        )
     raise SnowparkConnectNotImplementedError(f"Unsupported data type: {type_to_map}")
 
 
@@ -785,10 +853,14 @@ def map_simple_types(simple_type: str) -> snowpark.types.DataType:
             return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.NTZ)
         case "timestamp_ltz":
            return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.LTZ)
+        case "year_month_interval":
+            return snowpark.types.YearMonthIntervalType()
         case "day_time_interval":
-            # this is not a column type in snowflake so there won't be a dataframe column
-            # with this, for now this type won't make any sense
-            return snowpark.types.StringType()
+            return snowpark.types.DayTimeIntervalType()
+        case type_name if _INTERVAL_YEARMONTH_PATTERN_RE.match(type_name):
+            return snowpark.types.YearMonthIntervalType()
+        case type_name if _INTERVAL_DAYTIME_PATTERN_RE.match(type_name):
+            return snowpark.types.DayTimeIntervalType()
         case _:
             if simple_type.startswith("decimal"):
                 precision = int(simple_type.split("(")[1].split(",")[0])
@@ -30,9 +30,6 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )
 
-# Context for parsing map_partitions
-_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
-
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -230,16 +227,6 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)
 
 
-@contextmanager
-def push_map_partitions():
-    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
-    yield
-
-
-def map_partitions_depth() -> int:
-    return _map_partitions_stack.get()
-
-
 @contextmanager
 def push_sql_scope():
     """
@@ -410,7 +397,6 @@ def clear_context_data() -> None:
     _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
-    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})
@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry
 
-DESCRIBE_CACHE_TTL_SECONDS = 15
 USE_DESCRIBE_QUERY_CACHE = True
 
 DDL_DETECTION_PATTERN = re.compile(r"\s*(CREATE|ALTER|DROP)\b", re.IGNORECASE)
@@ -51,6 +50,8 @@ class DescribeQueryCache:
         return sql_query
 
     def get(self, sql_query: str) -> list[ResultMetadataV2] | None:
+        from snowflake.snowpark_connect.config import get_describe_cache_ttl_seconds
+
         telemetry.report_describe_query_cache_lookup()
 
         cache_key = self._get_cache_key(sql_query)
@@ -59,7 +60,9 @@ class DescribeQueryCache:
 
         if key in self._cache:
             result, timestamp = self._cache[key]
-            if current_time < timestamp + DESCRIBE_CACHE_TTL_SECONDS:
+
+            expired_by = current_time - (timestamp + get_describe_cache_ttl_seconds())
+            if expired_by < 0:
                 logger.debug(
                     f"Returning query result from cache for query: {sql_query[:20]}"
                 )
@@ -92,7 +95,7 @@ class DescribeQueryCache:
                 telemetry.report_describe_query_cache_hit()
                 return result
             else:
-                telemetry.report_describe_query_cache_expired()
+                telemetry.report_describe_query_cache_expired(expired_by)
                 del self._cache[key]
                 return None
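The cache now computes how far past its TTL an entry was at lookup time and reports that overshoot to telemetry instead of a bare expiry counter. The arithmetic, restated as a small sketch (function and argument names are illustrative):

    import time

    def seconds_past_ttl(inserted_at: float, ttl_seconds: float, now: float | None = None) -> float:
        # Negative -> still fresh; positive -> expired by that many seconds.
        now = time.time() if now is None else now
        return now - (inserted_at + ttl_seconds)

    assert seconds_past_ttl(inserted_at=100.0, ttl_seconds=15.0, now=110.0) == -5.0
    assert seconds_past_ttl(inserted_at=100.0, ttl_seconds=15.0, now=120.0) == 5.0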
 
@@ -3,10 +3,46 @@
 #
 import contextlib
 import functools
+import re
 
 from snowflake.snowpark import Session
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    create_file_format_statement,
+)
 from snowflake.snowpark_connect.utils.identifiers import FQN
 
+_MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")
+
+
+def cached_file_format(
+    session: Session, file_format: str, format_type_options: dict[str, str]
+) -> str:
+    """
+    Cache and return a file format name based on the given options.
+    """
+
+    function_name = _MINUS_AT_THE_BEGINNING_REGEX.sub(
+        "1", str(hash(frozenset(format_type_options.items())))
+    )
+    file_format_name = f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{function_name}"
+    if file_format_name in session._file_formats:
+        return file_format_name
+
+    session.sql(
+        create_file_format_statement(
+            file_format_name,
+            file_format,
+            format_type_options,
+            temp=True,
+            if_not_exist=True,
+            use_scoped_temp_objects=False,
+            is_generated=True,
+        )
+    ).collect()
+
+    session._file_formats.add(file_format_name)
+    return file_format_name
+
 
 @functools.cache
 def file_format(
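cached_file_format derives the format's object name from an order-independent hash of its options, replacing a leading minus sign with a digit so the suffix stays identifier-friendly. Restated as a sketch; note that Python string hashing is salted per process, so the name is only stable within one session:

    import re

    _MINUS_AT_THE_BEGINNING = re.compile(r"^-")

    def file_format_cache_name(file_format: str, options: dict[str, str]) -> str:
        digest = _MINUS_AT_THE_BEGINNING.sub("1", str(hash(frozenset(options.items()))))
        return f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{digest}"

    # Same options supplied in a different order map to the same cached name.
    a = file_format_cache_name("CSV", {"FIELD_DELIMITER": ",", "SKIP_HEADER": "1"})
    b = file_format_cache_name("CSV", {"SKIP_HEADER": "1", "FIELD_DELIMITER": ","})
    assert a == b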
@@ -71,6 +71,9 @@ def configure_snowpark_session(session: snowpark.Session):
     init_builtin_udf_cache(session)
     init_external_udxf_cache(session)
 
+    # file format cache
+    session._file_formats = set()
+
     # Set experimental parameters (warnings globally suppressed)
     session.ast_enabled = False
     session.eliminate_numeric_sql_value_cast_enabled = False
@@ -117,6 +120,7 @@ def configure_snowpark_session(session: snowpark.Session):
         "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": "false",  # this is required for creating udfs from sproc
         "ENABLE_STRUCTURED_TYPES_IN_SNOWPARK_CONNECT_RESPONSE": "true",
         "QUERY_TAG": f"'{query_tag}'",
+        "FEATURE_INTERVAL_TYPES": "enabled",
     }
 
     session.sql(