snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic.

Files changed (36)
  1. snowflake/snowpark_connect/config.py +12 -3
  2. snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
  3. snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
  4. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
  5. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  6. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  7. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  8. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  9. snowflake/snowpark_connect/relation/map_sql.py +112 -53
  10. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  11. snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
  12. snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
  13. snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
  14. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  15. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  16. snowflake/snowpark_connect/relation/write/map_write.py +95 -14
  17. snowflake/snowpark_connect/server.py +18 -13
  18. snowflake/snowpark_connect/utils/context.py +21 -14
  19. snowflake/snowpark_connect/utils/identifiers.py +8 -2
  20. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  21. snowflake/snowpark_connect/utils/session.py +3 -0
  22. snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
  23. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  24. snowflake/snowpark_connect/utils/udf_utils.py +9 -8
  25. snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
  26. snowflake/snowpark_connect/version.py +1 -1
  27. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
  28. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
  29. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
  30. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
  31. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
  32. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
  33. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
  34. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
  36. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/relation/read/map_read_text.py

@@ -43,7 +43,12 @@ def read_text(
 ) -> snowpark.DataFrame:
     # TODO: handle stage name with double quotes
     files_paths = get_file_paths_from_stage(path, session)
-    stage_name = path.split("/")[0]
+    # Remove matching quotes from both ends of the path to get the stage name, if present.
+    if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
+        unquoted_path = path[1:-1]
+    else:
+        unquoted_path = path
+    stage_name = unquoted_path.split("/")[0]
     line_sep = options.get("lineSep") or "\n"
     column_name = (
         schema[0].name if schema is not None and len(schema.fields) > 0 else '"value"'
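
The new block strips one pair of matching surrounding quotes before taking the first path segment, so quoted stage paths resolve to the correct stage name. A minimal standalone sketch of that logic (hypothetical helper name; in the release the code is inline in read_text):

    def stage_name_from_path(path: str) -> str:
        # Strip one pair of matching surrounding quotes, if present.
        if path and len(path) > 1 and path[0] == path[-1] and path[0] in ('"', "'"):
            path = path[1:-1]
        # The stage name is the first segment, e.g. '@MY_STAGE/dir/file.txt' -> '@MY_STAGE'.
        return path.split("/")[0]

    assert stage_name_from_path('"@MY_STAGE/data/part.txt"') == "@MY_STAGE"
    assert stage_name_from_path("@MY_STAGE/data/part.txt") == "@MY_STAGE"
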

snowflake/snowpark_connect/relation/stage_locator.py

@@ -5,6 +5,7 @@
 import os
 
 from fsspec.core import url_to_fs
+from pyspark.errors.exceptions.base import AnalysisException
 from s3fs.core import S3FileSystem
 
 from snowflake import snowpark
@@ -33,37 +34,42 @@ def get_paths_from_stage(
 
     # TODO : What if GCP?
     # TODO: What if already stage path?
-    if get_cloud_from_url(paths[0]) == "azure":
-        rewrite_paths = []
-        for p in paths:
-            _, bucket_name, path = parse_azure_url(p)
-            rewrite_paths.append(f"{stage_name}/{path}")
-        paths = rewrite_paths
-    else:
-        filesystem, parsed_path = url_to_fs(paths[0])
-        if isinstance(filesystem, S3FileSystem):  # aws
-            # Remove bucket name from the path since the stage name will replace
-            # the bucket name in the path.
-            paths = [
-                f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
-                for p in paths
-            ]
-        else:  # local
-            # For local files, we need to preserve directory structure for partitioned data
-            # Instead of just using basename, we'll use the last few path components
-            new_paths = []
+    match get_cloud_from_url(paths[0]):
+        case "azure":
+            rewrite_paths = []
             for p in paths:
-                # Split the path and take the last 2-3 components to preserve structure
-                # but avoid very long paths
-                path_parts = p.split(os.sep)
-                if len(path_parts) >= 2:
-                    # Take last 2 components (e.g., "base_case/x=abc")
-                    relative_path = "/".join(path_parts[-2:])
-                else:
-                    # Single component, use basename
-                    relative_path = os.path.basename(p)
-                new_paths.append(f"{stage_name}/{relative_path}")
-            paths = new_paths
+                _, bucket_name, path = parse_azure_url(p)
+                rewrite_paths.append(f"{stage_name}/{path}")
+            paths = rewrite_paths
+        case "gcp":
+            raise AnalysisException(
+                "You must configure an integration for Google Cloud Storage to perform I/O operations rather than accessing the URL directly. Reference: https://docs.snowflake.com/en/user-guide/data-load-gcs-config"
+            )
+        case _:
+            filesystem, parsed_path = url_to_fs(paths[0])
+            if isinstance(filesystem, S3FileSystem):  # aws
+                # Remove bucket name from the path since the stage name will replace
+                # the bucket name in the path.
+                paths = [
+                    f"{stage_name}/{'/'.join(url_to_fs(p)[1].split('/')[1:])}"
+                    for p in paths
+                ]
+            else:  # local
+                # For local files, we need to preserve directory structure for partitioned data
+                # Instead of just using basename, we'll use the last few path components
+                new_paths = []
+                for p in paths:
+                    # Split the path and take the last 2-3 components to preserve structure
+                    # but avoid very long paths
+                    path_parts = p.split(os.sep)
+                    if len(path_parts) >= 2:
+                        # Take last 2 components (e.g., "base_case/x=abc")
+                        relative_path = "/".join(path_parts[-2:])
+                    else:
+                        # Single component, use basename
+                        relative_path = os.path.basename(p)
+                    new_paths.append(f"{stage_name}/{relative_path}")
+                paths = new_paths
 
     return paths
 
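
The branch structure is now a match statement with an explicit gcp arm that fails fast with an actionable AnalysisException, while the AWS and local arms keep the previous behavior. For the local arm, the last two path components are preserved so that Hive-style partition directories survive staging. A standalone sketch of that local-path rewrite (hypothetical helper name, POSIX paths assumed):

    import os

    def rewrite_local_paths(paths: list[str], stage_name: str) -> list[str]:
        new_paths = []
        for p in paths:
            parts = p.split(os.sep)
            # Keep the last two components (e.g. 'base_case/x=abc') to preserve partitions.
            relative = "/".join(parts[-2:]) if len(parts) >= 2 else os.path.basename(p)
            new_paths.append(f"{stage_name}/{relative}")
        return new_paths

    print(rewrite_local_paths(["/tmp/data/base_case/x=abc"], "@tmp_stage"))
    # ['@tmp_stage/base_case/x=abc']
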
@@ -102,15 +108,21 @@ class StageLocator:
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='azure://{account}.blob.core.windows.net/{bucket_name}'"
 
             credential_session_key = (
-                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net"
+                f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
+                f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
             )
             credential = sessions_config.get(spark_session_id, None)
-            if (
-                credential is not None
-                and credential.get(credential_session_key) is not None
-                and credential.get(credential_session_key).strip() != ""
-            ):
-                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{credential.get(credential_session_key)}')"
+            sas_token = None
+            for session_key in credential_session_key:
+                if (
+                    credential is not None
+                    and credential.get(session_key) is not None
+                    and credential.get(session_key).strip() != ""
+                ):
+                    sas_token = credential.get(session_key)
+                    break
+            if sas_token is not None:
+                sql_query += f" CREDENTIALS = (AZURE_SAS_TOKEN = '{sas_token}')"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_azure[bucket_name] = stage_name
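
The SAS token lookup now accepts either of two session-config keys, preferring the account-level fixed token on the dfs endpoint and falling back to the container-scoped blob key; the first non-empty value wins. A hedged sketch of that resolution order (hypothetical helper, not part of the package):

    def resolve_sas_token(credential, account: str, bucket_name: str):
        candidate_keys = (
            f"fs.azure.sas.fixed.token.{account}.dfs.core.windows.net",
            f"fs.azure.sas.{bucket_name}.{account}.blob.core.windows.net",
        )
        for key in candidate_keys:
            value = (credential or {}).get(key)
            if value is not None and value.strip() != "":
                return value
        return None

    conf = {"fs.azure.sas.fixed.token.myacct.dfs.core.windows.net": "sv=2025-01-01&sig=example"}
    assert resolve_sas_token(conf, "myacct", "mycontainer") == "sv=2025-01-01&sig=example"
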
@@ -128,24 +140,44 @@ class StageLocator:
             # but the rest of the time it's used, it does. We just drop it here.
             sql_query = f"CREATE OR REPLACE TEMP STAGE {stage_name[1:]} URL='s3://{parsed_path.split('/')[0]}'"
             credential = sessions_config.get(spark_session_id, None)
-            if (
-                credential is not None
-                and credential.get("spark.hadoop.fs.s3a.access.key") is not None
-                and credential.get("spark.hadoop.fs.s3a.secret.key") is not None
-                and credential.get("spark.hadoop.fs.s3a.access.key").strip()
-                != ""
-                and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
-                != ""
-            ):
-                aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
-                aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
-                if (
-                    credential.get("spark.hadoop.fs.s3a.session.token")
+            if credential is not None:
+                if (  # USE AWS KEYS to connect
+                    credential.get("spark.hadoop.fs.s3a.access.key") is not None
+                    and credential.get("spark.hadoop.fs.s3a.secret.key")
+                    is not None
+                    and credential.get("spark.hadoop.fs.s3a.access.key").strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.secret.key").strip()
+                    != ""
+                ):
+                    aws_keys = f" AWS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.access.key')}'"
+                    aws_keys += f" AWS_SECRET_KEY = '{credential.get('spark.hadoop.fs.s3a.secret.key')}'"
+                    if (
+                        credential.get("spark.hadoop.fs.s3a.session.token")
+                        is not None
+                    ):
+                        aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
+                    sql_query += f" CREDENTIALS = ({aws_keys})"
+                    sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                elif (  # USE AWS ROLE and KMS KEY to connect
+                    credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    )
+                    is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.server-side-encryption.key"
+                    ).strip()
+                    != ""
+                    and credential.get("spark.hadoop.fs.s3a.assumed.role.arn")
                     is not None
+                    and credential.get(
+                        "spark.hadoop.fs.s3a.assumed.role.arn"
+                    ).strip()
+                    != ""
                 ):
-                    aws_keys += f" AWS_TOKEN = '{credential.get('spark.hadoop.fs.s3a.session.token')}'"
-                sql_query += f" CREDENTIALS = ({aws_keys})"
-                sql_query += " ENCRYPTION = ( TYPE = 'AWS_SSE_S3' )"
+                    aws_role = f" AWS_ROLE = '{credential.get('spark.hadoop.fs.s3a.assumed.role.arn')}'"
+                    sql_query += f" CREDENTIALS = ({aws_role})"
+                    sql_query += f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{credential.get('spark.hadoop.fs.s3a.server-side-encryption.key')}' )"
 
             logger.info(self.session.sql(sql_query).collect())
             self.stages_for_aws[bucket_name] = stage_name
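
Previously only static access/secret key pairs were honored; the new elif arm also supports an assumed IAM role plus a KMS key, in which case the temporary stage is created with AWS_ROLE credentials and AWS_SSE_KMS encryption. Roughly the SQL that role branch assembles (a sketch with made-up identifiers, mirroring the f-strings above):

    def build_role_based_stage_sql(stage_name: str, bucket: str, credential: dict) -> str:
        sql = f"CREATE OR REPLACE TEMP STAGE {stage_name} URL='s3://{bucket}'"
        role_arn = credential["spark.hadoop.fs.s3a.assumed.role.arn"]
        kms_key = credential["spark.hadoop.fs.s3a.server-side-encryption.key"]
        sql += f" CREDENTIALS = ( AWS_ROLE = '{role_arn}')"
        sql += f" ENCRYPTION = ( TYPE='AWS_SSE_KMS' KMS_KEY_ID = '{kms_key}' )"
        return sql

    print(
        build_role_based_stage_sql(
            "SPARK_CONNECT_STAGE_EXAMPLE",
            "my-bucket",
            {
                "spark.hadoop.fs.s3a.assumed.role.arn": "arn:aws:iam::123456789012:role/loader",
                "spark.hadoop.fs.s3a.server-side-encryption.key": "arn:aws:kms:us-west-2:123456789012:key/example",
            },
        )
    )
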

snowflake/snowpark_connect/relation/write/map_write.py

@@ -36,6 +36,8 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.io_utils import (
     convert_file_prefix_path,
     is_cloud_path,
+    is_supported_compression,
+    supported_compressions_for_format,
 )
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.relation.read.reader_config import CsvWriterConfig
@@ -179,7 +181,7 @@ def map_write(request: proto_base.ExecutePlanRequest):
                     f"Skipping REMOVE for root path {write_path} - too broad scope"
                 )
             else:
-                remove_command = f"REMOVE {write_path}/"
+                remove_command = f"REMOVE '{write_path}/'"
                 session.sql(remove_command).collect()
                 logger.info(f"Successfully cleared directory: {write_path}")
         except Exception as e:
@@ -208,6 +210,20 @@ def map_write(request: proto_base.ExecutePlanRequest):
             compression = write_op.options.get(
                 "compression", default_compression
             ).upper()
+
+            if not is_supported_compression(write_op.source, compression):
+                supported_compressions = supported_compressions_for_format(
+                    write_op.source
+                )
+                raise AnalysisException(
+                    f"Compression {compression} is not supported for {write_op.source} format. "
+                    + (
+                        f"Supported compressions: {sorted(supported_compressions)}"
+                        if supported_compressions
+                        else "No compression supported for this format."
+                    )
+                )
+
             parameters = {
                 "location": temp_file_prefix_on_stage,
                 "file_format_type": write_op.source
@@ -218,8 +234,9 @@
                 },
                 "overwrite": overwrite,
             }
-            # By default, download from the same prefix we wrote to.
-            download_stage_path = temp_file_prefix_on_stage
+            # Download from the base write path to ensure we fetch whatever Snowflake produced.
+            # Using the base avoids coupling to exact filenames/prefixes.
+            download_stage_path = write_path
 
             # Check for partition hint early to determine precedence over single option
             partition_hint = result.partition_hint
@@ -238,13 +255,19 @@
                     raise SnowparkConnectNotImplementedError(
                         "Partitioning is only supported for parquet format"
                     )
-                partitioning_columns = [f'"{c}"' for c in write_op.partitioning_columns]
-                if len(partitioning_columns) > 1:
-                    raise SnowparkConnectNotImplementedError(
-                        "Multiple partitioning columns are not yet supported"
-                    )
-                else:
-                    parameters["partition_by"] = partitioning_columns[0]
+                # Build Spark-style directory structure: col1=value1/col2=value2/...
+                # Example produced expression (Snowflake SQL):
+                # 'department=' || TO_VARCHAR("department") || '/' || 'region=' || TO_VARCHAR("region")
+                partitioning_column_names = list(write_op.partitioning_columns)
+                partition_expr_parts: list[str] = []
+                for col_name in partitioning_column_names:
+                    quoted = f'"{col_name}"'
+                    segment = f"'{col_name}=' || COALESCE(TO_VARCHAR({quoted}), '__HIVE_DEFAULT_PARTITION__')"
+                    partition_expr_parts.append(segment)
+                parameters["partition_by"] = " || '/' || ".join(partition_expr_parts)
+                # When using PARTITION BY, Snowflake writes into subdirectories under the base path.
+                # Download from the base write path to preserve partition directories locally.
+                download_stage_path = write_path
 
             # If a partition hint is present (from DataFrame.repartition(n)), optionally split the
             # write into n COPY INTO calls by assigning a synthetic partition id. Controlled by config.
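
Multi-column partitioning is now supported by assembling a Snowflake PARTITION BY expression that mirrors Spark's col=value directory layout, with NULL values mapped to __HIVE_DEFAULT_PARTITION__. A standalone sketch of the expression builder and its output:

    def build_partition_by(columns: list[str]) -> str:
        parts = []
        for col_name in columns:
            quoted = f'"{col_name}"'
            parts.append(
                f"'{col_name}=' || COALESCE(TO_VARCHAR({quoted}), '__HIVE_DEFAULT_PARTITION__')"
            )
        return " || '/' || ".join(parts)

    print(build_partition_by(["department", "region"]))
    # 'department=' || COALESCE(TO_VARCHAR("department"), '__HIVE_DEFAULT_PARTITION__')
    #   || '/' || 'region=' || COALESCE(TO_VARCHAR("region"), '__HIVE_DEFAULT_PARTITION__')
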
@@ -410,9 +433,27 @@
                 )
             case _:
                 snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
+                save_method = write_op.table.save_method
+
+                if (
+                    write_op.source == "snowflake"
+                    and write_op.table.save_method
+                    == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_UNSPECIFIED
+                ):
+                    save_method = (
+                        commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
+                    )
+                    if len(write_op.table.table_name) == 0:
+                        dbtable_name = write_op.options.get("dbtable", "")
+                        if len(dbtable_name) == 0:
+                            raise SnowparkConnectNotImplementedError(
+                                "Save command is not supported without a table name"
+                            )
+                        else:
+                            snowpark_table_name = _spark_to_snowflake(dbtable_name)
 
                 if (
-                    write_op.table.save_method
+                    save_method
                     == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
                 ):
                     match write_mode:
@@ -474,7 +515,7 @@
                                 column_order=_column_order_for_write,
                             )
                 elif (
-                    write_op.table.save_method
+                    save_method
                     == commands_proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
                 ):
                     _validate_schema_and_get_writer(
@@ -486,7 +527,7 @@
                     )
                 else:
                     raise SnowparkConnectNotImplementedError(
-                        f"Save command not supported: {write_op.table.save_method}"
+                        f"Save command not supported: {save_method}"
                     )
 
 
@@ -978,7 +1019,47 @@ def store_files_locally(
     )
     if overwrite and os.path.isdir(target_path):
         _truncate_directory(real_path)
-    snowpark.file_operation.FileOperation(session).get(stage_path, str(real_path))
+    # Per Snowflake docs: "The command does not preserve stage directory structure when transferring files to your client machine"
+    # https://docs.snowflake.com/en/sql-reference/sql/get
+    # Preserve directory structure under stage_path by listing files and
+    # downloading each into its corresponding local subdirectory when partition subdirs exist.
+    # Otherwise, fall back to a direct GET which flattens.
+
+    # TODO(SNOW-2326973): This can be parallelized further. Its not done here because it only affects
+    # write to local storage.
+
+    ls_dataframe = session.sql(f"LS {stage_path}")
+    ls_iterator = ls_dataframe.toLocalIterator()
+
+    # Build a normalized base prefix from stage_path to compute relatives
+    # Example: stage_path='@MY_STAGE/prefix' -> base_prefix='my_stage/prefix/'
+    base_prefix = stage_path.lstrip("@").rstrip("/") + "/"
+    base_prefix_lower = base_prefix.lower()
+
+    # Group by parent directory under the base prefix, then issue a GET per directory.
+    # This gives a small parallelism advantage if we have many files per partition directory.
+    parent_dirs: set[str] = set()
+    for row in ls_iterator:
+        name: str = row[0]
+        name_lower = name.lower()
+        rel_start = name_lower.find(base_prefix_lower)
+        relative = name[rel_start + len(base_prefix) :] if rel_start != -1 else name
+        parent_dir = os.path.dirname(relative)
+        if parent_dir and parent_dir != ".":
+            parent_dirs.add(parent_dir)
+
+    # If no parent directories were discovered (non-partitioned unload prefix), use direct GET.
+    if not parent_dirs:
+        snowpark.file_operation.FileOperation(session).get(stage_path, str(real_path))
+        return
+
+    file_op = snowpark.file_operation.FileOperation(session)
+    for parent_dir in sorted(parent_dirs):
+        local_dir = real_path / parent_dir
+        os.makedirs(local_dir, exist_ok=True)
+
+        src_dir = f"@{base_prefix}{parent_dir}"
+        file_op.get(src_dir, str(local_dir))
 
 
 def _truncate_directory(directory_path: Path) -> None:
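
Because GET flattens stage directory structure, the new code lists the unloaded files, derives each file's partition subdirectory relative to the stage prefix, and issues one GET per subdirectory into a matching local directory. A self-contained sketch of the relative-path grouping, using a fabricated LS listing:

    import os

    stage_path = "@MY_STAGE/out"
    listed_names = [
        "my_stage/out/department=hr/data_0_0_0.snappy.parquet",
        "my_stage/out/department=eng/data_0_0_0.snappy.parquet",
    ]

    base_prefix = stage_path.lstrip("@").rstrip("/") + "/"   # 'MY_STAGE/out/'
    base_prefix_lower = base_prefix.lower()

    parent_dirs = set()
    for name in listed_names:
        rel_start = name.lower().find(base_prefix_lower)
        relative = name[rel_start + len(base_prefix):] if rel_start != -1 else name
        parent_dir = os.path.dirname(relative)
        if parent_dir and parent_dir != ".":
            parent_dirs.add(parent_dir)

    print(sorted(parent_dirs))  # ['department=eng', 'department=hr']
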

snowflake/snowpark_connect/server.py

@@ -1161,23 +1161,28 @@ def get_session(url: Optional[str] = None, conf: SparkConf = None) -> SparkSession:
 
 
 def init_spark_session(conf: SparkConf = None) -> SparkSession:
-    try:
-        # For Notebooks on SPCS
-        from jdk4py import JAVA_HOME
-
-        os.environ["JAVA_HOME"] = str(JAVA_HOME)
-    except ModuleNotFoundError:
-        # For notebooks on Warehouse
-        os.environ["JAVA_HOME"] = os.environ["CONDA_PREFIX"]
-        os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
-            os.environ["CONDA_PREFIX"], "lib", "server"
-        )
-    logger.info("JAVA_HOME=%s", os.environ["JAVA_HOME"])
+    if os.environ.get("JAVA_HOME") is None:
+        try:
+            # For Notebooks on SPCS
+            from jdk4py import JAVA_HOME
+
+            os.environ["JAVA_HOME"] = str(JAVA_HOME)
+        except ModuleNotFoundError:
+            # For notebooks on Warehouse
+            conda_prefix = os.environ.get("CONDA_PREFIX")
+            if conda_prefix is not None:
+                os.environ["JAVA_HOME"] = conda_prefix
+                os.environ["JAVA_LD_LIBRARY_PATH"] = os.path.join(
+                    conda_prefix, "lib", "server"
+                )
+    logger.info("JAVA_HOME=%s", os.environ.get("JAVA_HOME", "Not defined"))
 
     os.environ["SPARK_LOCAL_HOSTNAME"] = "127.0.0.1"
     os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
 
-    snowpark_session = snowpark.context.get_active_session()
+    from snowflake.snowpark_connect.utils.session import _get_current_snowpark_session
+
+    snowpark_session = _get_current_snowpark_session()
     start_session(snowpark_session=snowpark_session)
     return get_session(conf=conf)
 

snowflake/snowpark_connect/utils/context.py

@@ -30,9 +30,6 @@ _sql_aggregate_function_count = ContextVar[int](
     "_contains_aggregate_function", default=0
 )
 
-# Context for parsing map_partitions
-_map_partitions_stack = ContextVar[int]("_map_partitions_stack", default=0)
-
 # We have to generate our own plan IDs that are different from Spark's.
 # Spark plan IDs start at 0, so pick a "big enough" number to avoid overlaps.
 _STARTING_SQL_PLAN_ID = 0x80000000
@@ -70,6 +67,26 @@ _lca_alias_map: ContextVar[dict[str, TypedColumn]] = ContextVar(
     default={},
 )
 
+_view_process_context = ContextVar("_view_process_context", default=[])
+
+
+@contextmanager
+def push_processed_view(name: str):
+    _view_process_context.set(_view_process_context.get() + [name])
+    yield
+    _view_process_context.set(_view_process_context.get()[:-1])
+
+
+def get_processed_views() -> list[str]:
+    return _view_process_context.get()
+
+
+def register_processed_view(name: str) -> None:
+    context = _view_process_context.get()
+    context.append(name)
+    _view_process_context.set(context)
+
+
 # Context variable to track current grouping columns for grouping_id() function
 _current_grouping_columns: ContextVar[list[str]] = ContextVar(
     "_current_grouping_columns",
@@ -210,16 +227,6 @@ def push_evaluating_join_condition(join_type, left_keys, right_keys):
         _is_evaluating_join_condition.set(prev)
 
 
-@contextmanager
-def push_map_partitions():
-    _map_partitions_stack.set(_map_partitions_stack.get() + 1)
-    yield
-
-
-def map_partitions_depth() -> int:
-    return _map_partitions_stack.get()
-
-
 @contextmanager
 def push_sql_scope():
     """
@@ -387,9 +394,9 @@ def clear_context_data() -> None:
     _plan_id_map.set({})
     _alias_map.set({})
 
+    _view_process_context.set([])
     _next_sql_plan_id.set(_STARTING_SQL_PLAN_ID)
     _sql_plan_name_map.set({})
-    _map_partitions_stack.set(0)
     _sql_aggregate_function_count.set(0)
     _sql_named_args.set({})
     _sql_pos_args.set({})

snowflake/snowpark_connect/utils/identifiers.py

@@ -28,12 +28,18 @@ def unquote_spark_identifier_if_quoted(spark_name: str) -> str:
     raise AnalysisException(f"Invalid name: {spark_name}")
 
 
-def spark_to_sf_single_id_with_unquoting(name: str) -> str:
+def spark_to_sf_single_id_with_unquoting(
+    name: str, use_auto_upper_case: bool = False
+) -> str:
     """
     Transforms a spark name to a valid snowflake name by quoting and potentially uppercasing it.
     Unquotes the spark name if necessary. Will raise an AnalysisException if given name is not valid.
     """
-    return spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+    return (
+        spark_to_sf_single_id(unquote_spark_identifier_if_quoted(name))
+        if use_auto_upper_case
+        else quote_name_without_upper_casing(unquote_spark_identifier_if_quoted(name))
+    )
 
 
 def spark_to_sf_single_id(name: str, is_column: bool = False) -> str:
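
The default behavior changes from uppercasing Spark identifiers to quoting them with their original case preserved; the old uppercasing path remains available behind the new use_auto_upper_case flag. A heavily simplified contrast of the two paths (the real helpers also validate identifiers and handle embedded quotes):

    def to_snowflake_identifier(name: str, use_auto_upper_case: bool = False) -> str:
        # Simplified: quote as-is by default, uppercase-then-quote only on request.
        return f'"{name.upper()}"' if use_auto_upper_case else f'"{name}"'

    print(to_snowflake_identifier("myTable"))                            # "myTable"
    print(to_snowflake_identifier("myTable", use_auto_upper_case=True))  # "MYTABLE"
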

snowflake/snowpark_connect/utils/io_utils.py

@@ -3,10 +3,46 @@
 #
 import contextlib
 import functools
+import re
 
 from snowflake.snowpark import Session
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    create_file_format_statement,
+)
 from snowflake.snowpark_connect.utils.identifiers import FQN
 
+_MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")
+
+
+def cached_file_format(
+    session: Session, file_format: str, format_type_options: dict[str, str]
+) -> str:
+    """
+    Cache and return a file format name based on the given options.
+    """
+
+    function_name = _MINUS_AT_THE_BEGINNING_REGEX.sub(
+        "1", str(hash(frozenset(format_type_options.items())))
+    )
+    file_format_name = f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{function_name}"
+    if file_format_name in session._file_formats:
+        return file_format_name
+
+    session.sql(
+        create_file_format_statement(
+            file_format_name,
+            file_format,
+            format_type_options,
+            temp=True,
+            if_not_exist=True,
+            use_scoped_temp_objects=False,
+            is_generated=True,
+        )
+    ).collect()
+
+    session._file_formats.add(file_format_name)
+    return file_format_name
+
 
 @functools.cache
 def file_format(
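
cached_file_format derives a name from the format plus its options (a hash of a frozenset, so option order does not matter, with a leading minus sign rewritten so the suffix stays digit-only) and creates the TEMP FILE FORMAT only the first time that name is seen in the session. A standalone sketch of the naming scheme:

    import re

    _MINUS_AT_THE_BEGINNING_REGEX = re.compile(r"^-")

    def file_format_cache_name(file_format: str, format_type_options: dict[str, str]) -> str:
        digest = _MINUS_AT_THE_BEGINNING_REGEX.sub(
            "1", str(hash(frozenset(format_type_options.items())))
        )
        return f"__SNOWPARK_CONNECT_FILE_FORMAT__{file_format}_{digest}"

    a = file_format_cache_name("CSV", {"FIELD_DELIMITER": "','", "SKIP_HEADER": "1"})
    b = file_format_cache_name("CSV", {"SKIP_HEADER": "1", "FIELD_DELIMITER": "','"})
    assert a == b  # option order does not change the cached format name

Note that Python's string hashing is salted per process, so the generated name is only stable within one server process; that appears sufficient here since the formats are created as TEMP objects and tracked in the per-session set added in utils/session.py below.
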

snowflake/snowpark_connect/utils/session.py

@@ -71,6 +71,9 @@ def configure_snowpark_session(session: snowpark.Session):
     init_builtin_udf_cache(session)
     init_external_udxf_cache(session)
 
+    # file format cache
+    session._file_formats = set()
+
     # Set experimental parameters (warnings globally suppressed)
     session.ast_enabled = False
     session.eliminate_numeric_sql_value_cast_enabled = False

snowflake/snowpark_connect/utils/temporary_view_cache.py (new file)

@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Optional, Tuple
+
+from pyspark.errors import AnalysisException
+
+from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
+from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
+from snowflake.snowpark_connect.utils.context import get_session_id
+
+_temp_views = SynchronizedDict[Tuple[str, str], DataFrameContainer]()
+
+
+def register_temp_view(name: str, df: DataFrameContainer, replace: bool) -> None:
+    normalized_name = _normalize(name)
+    current_session_id = get_session_id()
+    for key in list(_temp_views.keys()):
+        if _normalize(key[0]) == normalized_name and key[1] == current_session_id:
+            if replace:
+                _temp_views.remove(key)
+                break
+            else:
+                raise AnalysisException(
+                    f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{name}` because it already exists."
+                )
+
+    _temp_views[(name, current_session_id)] = df
+
+
+def unregister_temp_view(name: str) -> bool:
+    normalized_name = _normalize(name)
+
+    for key in _temp_views.keys():
+        normalized_key = _normalize(key[0])
+        if normalized_name == normalized_key and key[1] == get_session_id():
+            pop_result = _temp_views.remove(key)
+            return pop_result is not None
+    return False
+
+
+def get_temp_view(name: str) -> Optional[DataFrameContainer]:
+    normalized_name = _normalize(name)
+    for key in _temp_views.keys():
+        normalized_key = _normalize(key[0])
+        if normalized_name == normalized_key and key[1] == get_session_id():
+            return _temp_views.get(key)
+    return None
+
+
+def get_temp_view_normalized_names() -> list[str]:
+    return [
+        _normalize(key[0]) for key in _temp_views.keys() if key[1] == get_session_id()
+    ]
+
+
+def _normalize(name: str) -> str:
+    from snowflake.snowpark_connect.config import global_config
+
+    return name if global_config.spark_sql_caseSensitive else name.lower()
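
The new module gives each Spark Connect session its own temporary-view namespace: entries are keyed by (view name, session id), name matching honors spark.sql.caseSensitive, and re-registering without replace raises the Spark TEMP_TABLE_OR_VIEW_ALREADY_EXISTS error. A simplified, dependency-free model of that keying and lookup behavior (not the package's API):

    from typing import Optional

    class TempViewCacheModel:
        """Illustrative stand-in for the module above (no SynchronizedDict, no session context)."""

        def __init__(self, case_sensitive: bool = False) -> None:
            self._views: dict[tuple[str, str], object] = {}
            self._case_sensitive = case_sensitive

        def _normalize(self, name: str) -> str:
            return name if self._case_sensitive else name.lower()

        def register(self, name: str, session_id: str, df: object, replace: bool) -> None:
            for key in list(self._views):
                if self._normalize(key[0]) == self._normalize(name) and key[1] == session_id:
                    if not replace:
                        raise ValueError(f"Temporary view `{name}` already exists")
                    del self._views[key]
                    break
            self._views[(name, session_id)] = df

        def get(self, name: str, session_id: str) -> Optional[object]:
            for key in self._views:
                if self._normalize(key[0]) == self._normalize(name) and key[1] == session_id:
                    return self._views[key]
            return None

    cache = TempViewCacheModel()
    cache.register("People", "sess-1", df="<dataframe>", replace=False)
    assert cache.get("people", "sess-1") == "<dataframe>"  # case-insensitive by default
    assert cache.get("people", "sess-2") is None           # views are isolated per session
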