snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (36)
  1. snowflake/snowpark_connect/config.py +12 -3
  2. snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
  3. snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
  4. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
  5. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  6. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  7. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  8. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  9. snowflake/snowpark_connect/relation/map_sql.py +112 -53
  10. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  11. snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
  12. snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
  13. snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
  14. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  15. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  16. snowflake/snowpark_connect/relation/write/map_write.py +95 -14
  17. snowflake/snowpark_connect/server.py +18 -13
  18. snowflake/snowpark_connect/utils/context.py +21 -14
  19. snowflake/snowpark_connect/utils/identifiers.py +8 -2
  20. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  21. snowflake/snowpark_connect/utils/session.py +3 -0
  22. snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
  23. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  24. snowflake/snowpark_connect/utils/udf_utils.py +9 -8
  25. snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
  26. snowflake/snowpark_connect/version.py +1 -1
  27. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
  28. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
  29. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
  30. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
  31. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
  32. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
  33. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
  34. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
  36. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_sql.py

@@ -61,6 +61,7 @@ from snowflake.snowpark_connect.utils.context import (
     get_session_id,
     get_sql_plan,
     push_evaluating_sql_scope,
+    push_processed_view,
     push_sql_scope,
     set_plan_id_map,
     set_sql_args,
@@ -80,7 +81,16 @@ from ..expression.map_sql_expression import (
     map_logical_plan_expression,
     sql_parser,
 )
-from ..utils.identifiers import spark_to_sf_single_id
+from ..utils.identifiers import (
+    spark_to_sf_single_id,
+    spark_to_sf_single_id_with_unquoting,
+)
+from ..utils.temporary_view_cache import (
+    get_temp_view,
+    register_temp_view,
+    unregister_temp_view,
+)
+from .catalogs import SNOWFLAKE_CATALOG
 
 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
 _cte_definitions = ContextVar[dict[str, any]]("_cte_definitions", default={})
@@ -403,6 +413,7 @@ def map_sql_to_pandas_df(
     ) == "UnresolvedHint":
         logical_plan = logical_plan.child()
 
+    # TODO: Add support for temporary views for SQL cases such as ShowViews, ShowColumns ect. (Currently the cases are not compatible with Spark, returning raw Snowflake rows)
     match class_name:
         case "AddColumns":
             # Handle ALTER TABLE ... ADD COLUMNS (col_name data_type) -> ADD COLUMN col_name data_type
@@ -577,6 +588,23 @@ def map_sql_to_pandas_df(
             )
             snowflake_sql = parsed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
+            spark_view_name = next(
+                sqlglot.parse_one(sql_string, dialect="spark").find_all(
+                    sqlglot.exp.Table
+                )
+            ).name
+            snowflake_view_name = spark_to_sf_single_id_with_unquoting(
+                spark_view_name
+            )
+            temp_view = get_temp_view(snowflake_view_name)
+            if temp_view is not None and not logical_plan.replace():
+                raise AnalysisException(
+                    f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{spark_view_name}` because it already exists."
+                )
+            else:
+                unregister_temp_view(
+                    spark_to_sf_single_id_with_unquoting(spark_view_name)
+                )
         case "CreateView":
             current_schema = session.connection.schema
             if (
@@ -613,50 +641,60 @@ def map_sql_to_pandas_df(
                 else None,
             )
         case "CreateViewCommand":
-            df_container = execute_logical_plan(logical_plan.plan())
-            df = df_container.dataframe
-            tmp_views = _get_current_temp_objects()
-            tmp_views.add(
-                (
-                    CURRENT_CATALOG_NAME,
-                    session.connection.schema,
-                    str(logical_plan.name().identifier()),
+            with push_processed_view(logical_plan.name().identifier()):
+                df_container = execute_logical_plan(logical_plan.plan())
+                df = df_container.dataframe
+                user_specified_spark_column_names = [
+                    str(col._1())
+                    for col in as_java_list(logical_plan.userSpecifiedColumns())
+                ]
+                df_container = DataFrameContainer.create_with_column_mapping(
+                    dataframe=df,
+                    spark_column_names=user_specified_spark_column_names
+                    if user_specified_spark_column_names
+                    else df_container.column_map.get_spark_columns(),
+                    snowpark_column_names=df_container.column_map.get_snowpark_columns(),
+                    parent_column_name_map=df_container.column_map,
                 )
-            )
-
-            name = str(logical_plan.name().identifier())
-            name = spark_to_sf_single_id(name)
-            if isinstance(
-                logical_plan.viewType(),
-                jpype.JClass(
-                    "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
-                ),
-            ):
-                name = f"{global_config.spark_sql_globalTempDatabase}.{name}"
-            comment = logical_plan.comment()
-            maybe_comment = (
-                _escape_sql_comment(str(comment.get()))
-                if comment.isDefined()
-                else None
-            )
-
-            df = _rename_columns(
-                df, logical_plan.userSpecifiedColumns(), df_container.column_map
-            )
-
-            if logical_plan.replace():
-                df.create_or_replace_temp_view(
-                    name,
-                    comment=maybe_comment,
+                is_global = isinstance(
+                    logical_plan.viewType(),
+                    jpype.JClass(
+                        "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
+                    ),
                 )
-            else:
-                df.create_temp_view(
-                    name,
-                    comment=maybe_comment,
+                if is_global:
+                    view_name = [
+                        global_config.spark_sql_globalTempDatabase,
+                        logical_plan.name().quotedString(),
+                    ]
+                else:
+                    view_name = [logical_plan.name().quotedString()]
+                view_name = [
+                    spark_to_sf_single_id_with_unquoting(part) for part in view_name
+                ]
+                joined_view_name = ".".join(view_name)
+
+                register_temp_view(
+                    joined_view_name,
+                    df_container,
+                    logical_plan.replace(),
+                )
+                tmp_views = _get_current_temp_objects()
+                tmp_views.add(
+                    (
+                        CURRENT_CATALOG_NAME,
+                        session.connection.schema,
+                        str(logical_plan.name().identifier()),
+                    )
                 )
         case "DescribeColumn":
-            name = get_relation_identifier_name(logical_plan.column())
+            name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.column()
+            )
+            if get_temp_view(name):
+                return SNOWFLAKE_CATALOG.listColumns(unquote_if_quoted(name)), ""
             # todo double check if this is correct
+            name = get_relation_identifier_name(logical_plan.column())
             rows = session.sql(f"DESCRIBE TABLE {name}").collect()
         case "DescribeNamespace":
             name = get_relation_identifier_name(logical_plan.namespace(), True)
@@ -731,9 +769,13 @@ def map_sql_to_pandas_df(
             if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
             session.sql(f"DROP TABLE {if_exists}{name}").collect()
         case "DropView":
-            name = get_relation_identifier_name(logical_plan.child())
-            if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
-            session.sql(f"DROP VIEW {if_exists}{name}").collect()
+            temporary_view_name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.child()
+            )
+            if not unregister_temp_view(temporary_view_name):
+                name = get_relation_identifier_name(logical_plan.child())
+                if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
+                session.sql(f"DROP VIEW {if_exists}{name}").collect()
         case "ExplainCommand":
            inner_plan = logical_plan.logicalPlan()
            logical_plan_name = inner_plan.nodeName()
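
With CreateViewCommand now registering temporary views in an in-process cache (utils/temporary_view_cache.py) and DropView unregistering them before falling back to Snowflake, the user-visible behavior moves closer to Spark's. A rough client-side sketch of what this enables; the Spark Connect endpoint below is an assumption, not taken from the package:

    from pyspark.errors import AnalysisException
    from pyspark.sql import SparkSession

    # Hypothetical endpoint; point this at your own snowpark-connect server.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    spark.sql("CREATE TEMPORARY VIEW my_temp_view AS SELECT 1 AS id")
    try:
        # Re-creating without OR REPLACE should now surface Spark's
        # [TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] error instead of a Snowflake one.
        spark.sql("CREATE TEMPORARY VIEW my_temp_view AS SELECT 2 AS id")
    except AnalysisException as exc:
        print(exc)

    # DropView first unregisters the cached temporary view; only when nothing is
    # cached does it fall back to an actual DROP VIEW statement in Snowflake.
    spark.sql("DROP VIEW my_temp_view")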
@@ -2173,21 +2215,38 @@ def map_logical_plan_relation(
     return proto
 
 
-def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
+def _get_relation_identifier(name_obj) -> str:
+    # IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
+    expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
+    session = snowpark.Session.get_active_session()
+    m = ColumnNameMap([], [], None)
+    expr = map_single_column_expression(
+        expr_proto, m, ExpressionTyper.dummy_typer(session)
+    )
+    return spark_to_sf_single_id(session.range(1).select(expr[1].col).collect()[0][0])
+
+
+def get_relation_identifier_name_without_uppercasing(name_obj) -> str:
     if name_obj.getClass().getSimpleName() in (
         "PlanWithUnresolvedIdentifier",
         "ExpressionWithUnresolvedIdentifier",
     ):
-        # IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
-        expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
-        session = snowpark.Session.get_active_session()
-        m = ColumnNameMap([], [], None)
-        expr = map_single_column_expression(
-            expr_proto, m, ExpressionTyper.dummy_typer(session)
-        )
-        name = spark_to_sf_single_id(
-            session.range(1).select(expr[1].col).collect()[0][0]
+        return _get_relation_identifier(name_obj)
+    else:
+        name = ".".join(
+            quote_name_without_upper_casing(str(part))
+            for part in as_java_list(name_obj.nameParts())
         )
+
+    return name
+
+
+def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
+    if name_obj.getClass().getSimpleName() in (
+        "PlanWithUnresolvedIdentifier",
+        "ExpressionWithUnresolvedIdentifier",
+    ):
+        return _get_relation_identifier(name_obj)
     else:
         if is_multi_part:
             try:
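
The refactor above splits the old helper in three: the IDENTIFIER(...) evaluation moves into _get_relation_identifier, and the new get_relation_identifier_name_without_uppercasing joins the name parts with quote_name_without_upper_casing so the case-preserving name can be matched against the temporary view cache. Roughly, as a simplified illustration of the quoting (not the package's code):

    def _join_parts_preserving_case(parts: list[str]) -> str:
        # Approximation of quote_name_without_upper_casing applied per part.
        return ".".join(f'"{p}"' for p in parts)

    print(_join_parts_preserving_case(["db1", "myView"]))  # "db1"."myView"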
snowflake/snowpark_connect/relation/read/map_read.py

@@ -46,6 +46,9 @@ def map_read(
 
     Currently, the supported read formats are `csv`, `json` and `parquet`.
     """
+
+    materialize_df = True
+
     match rel.read.WhichOneof("read_type"):
         case "named_table":
             return map_read_table_or_file(rel)
@@ -99,6 +102,10 @@ def map_read(
                 for path in rel.read.data_source.paths
             ]
 
+            # JSON already materializes the table internally
+            if read_format == "json":
+                materialize_df = False
+
             result = _read_file(
                 clean_source_paths, options, read_format, rel, schema, session
             )
@@ -159,7 +166,9 @@ def map_read(
             raise SnowparkConnectNotImplementedError(f"Unsupported read type: {other}")
 
     return df_cache_map_put_if_absent(
-        (get_session_id(), rel.common.plan_id), lambda: result, materialize=True
+        (get_session_id(), rel.common.plan_id),
+        lambda: result,
+        materialize=materialize_df,
     )
 
 
@@ -205,6 +214,15 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
     return None
 
 
+def _quote_stage_path(stage_path: str) -> str:
+    """
+    Quote stage paths to escape any special characters.
+    """
+    if stage_path.startswith("@"):
+        return f"'{stage_path}'"
+    return stage_path
+
+
 def _read_file(
     clean_source_paths: list[str],
     options: dict,
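
_quote_stage_path only wraps stage-relative paths (those starting with "@") in single quotes so that spaces and other special characters survive SQL generation; local paths pass through untouched. For example (the path names are made up):

    def _quote_stage_path(stage_path: str) -> str:
        # Same logic as the helper added above.
        if stage_path.startswith("@"):
            return f"'{stage_path}'"
        return stage_path

    print(_quote_stage_path("@my_stage/dir with spaces/data.csv"))  # '@my_stage/dir with spaces/data.csv'
    print(_quote_stage_path("/tmp/local.csv"))                      # unchanged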
@@ -218,6 +236,7 @@ def _read_file(
         session,
     )
     upload_files_if_needed(paths, clean_source_paths, session, read_format)
+    paths = [_quote_stage_path(path) for path in paths]
    match read_format:
        case "csv":
            from snowflake.snowpark_connect.relation.read.map_read_csv import (
@@ -285,8 +304,8 @@ def upload_files_if_needed(
 
     def _upload_dir(target: str, source: str) -> None:
         # overwrite=True will not remove all stale files in the target prefix
-
-        remove_command = f"REMOVE {target}/"
+        # Quote the target path to allow special characters.
+        remove_command = f"REMOVE '{target}/'"
         assert (
             "//" not in remove_command
         ), f"Remove command {remove_command} contains double slash"
snowflake/snowpark_connect/relation/read/map_read_csv.py

@@ -3,6 +3,7 @@
 #
 
 import copy
+from typing import Any
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
@@ -16,6 +17,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -42,21 +44,34 @@ def map_read_csv(
         )
     else:
         snowpark_options = options.convert_to_snowpark_args()
+        parse_header = snowpark_options.get("PARSE_HEADER", False)
+        file_format_options = _parse_csv_snowpark_options(snowpark_options)
+        file_format = cached_file_format(session, "csv", file_format_options)
+
+        snowpark_read_options = dict()
+        snowpark_read_options["FORMAT_NAME"] = file_format
+        snowpark_read_options["ENFORCE_EXISTING_FILE_FORMAT"] = True
+        snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
+            "INFER_SCHEMA", False
+        )
+        snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)
+
         raw_options = rel.read.data_source.options
         if schema is None or (
-            snowpark_options.get("PARSE_HEADER", False)
-            and raw_options.get("enforceSchema", "True").lower() == "false"
+            parse_header and raw_options.get("enforceSchema", "True").lower() == "false"
         ):  # Schema has to equals to header's format
-            reader = session.read.options(snowpark_options)
+            reader = session.read.options(snowpark_read_options)
         else:
-            reader = session.read.options(snowpark_options).schema(schema)
+            reader = session.read.options(snowpark_read_options).schema(schema)
         df = read_data(
             reader,
             schema,
             session,
             paths[0],
-            snowpark_options,
+            file_format_options,
+            snowpark_read_options,
             raw_options,
+            parse_header,
         )
         if len(paths) > 1:
             # TODO: figure out if this is what Spark does.
@@ -81,15 +96,65 @@ def map_read_csv(
     )
 
 
+_csv_file_format_allowed_options = {
+    "COMPRESSION",
+    "RECORD_DELIMITER",
+    "FIELD_DELIMITER",
+    "MULTI_LINE",
+    "FILE_EXTENSION",
+    "PARSE_HEADER",
+    "SKIP_HEADER",
+    "SKIP_BLANK_LINES",
+    "DATE_FORMAT",
+    "TIME_FORMAT",
+    "TIMESTAMP_FORMAT",
+    "BINARY_FORMAT",
+    "ESCAPE",
+    "ESCAPE_UNENCLOSED_FIELD",
+    "TRIM_SPACE",
+    "FIELD_OPTIONALLY_ENCLOSED_BY",
+    "NULL_IF",
+    "ERROR_ON_COLUMN_COUNT_MISMATCH",
+    "REPLACE_INVALID_CHARACTERS",
+    "EMPTY_FIELD_AS_NULL",
+    "SKIP_BYTE_ORDER_MARK",
+    "ENCODING",
+}
+
+
+def _parse_csv_snowpark_options(snowpark_options: dict[str, Any]) -> dict[str, Any]:
+    file_format_options = dict()
+    for key, value in snowpark_options.items():
+        upper_key = key.upper()
+        if upper_key in _csv_file_format_allowed_options:
+            file_format_options[upper_key] = value
+
+    # This option has to be removed, because we cannot use at the same time predefined file format and parse_header option
+    # Such combination causes snowpark to raise SQL compilation error: Invalid file format "PARSE_HEADER" is only allowed for CSV INFER_SCHEMA and MATCH_BY_COLUMN_NAME
+    parse_header = file_format_options.get("PARSE_HEADER", False)
+    if parse_header:
+        file_format_options["SKIP_HEADER"] = 1
+        del file_format_options["PARSE_HEADER"]
+
+    return file_format_options
+
+
 def get_header_names(
     session: snowpark.Session,
     path: list[str],
-    snowpark_options: dict,
+    file_format_options: dict,
+    snowpark_read_options: dict,
 ) -> list[str]:
-    snowpark_options_no_header = copy.copy(snowpark_options)
-    snowpark_options_no_header["PARSE_HEADER"] = False
+    no_header_file_format_options = copy.copy(file_format_options)
+    no_header_file_format_options["PARSE_HEADER"] = False
+    no_header_file_format_options.pop("SKIP_HEADER", None)
+
+    file_format = cached_file_format(session, "csv", no_header_file_format_options)
+    no_header_snowpark_read_options = copy.copy(snowpark_read_options)
+    no_header_snowpark_read_options["FORMAT_NAME"] = file_format
+    no_header_snowpark_read_options.pop("INFER_SCHEMA", None)
 
-    header_df = session.read.options(snowpark_options_no_header).csv(path).limit(1)
+    header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
     header_data = header_df.collect()[0]
     return [
         f'"{header_data[i]}"'
@@ -103,8 +168,10 @@ def read_data(
     schema: snowpark.types.StructType | None,
     session: snowpark.Session,
     path: list[str],
-    snowpark_options: dict,
+    file_format_options: dict,
+    snowpark_read_options: dict,
     raw_options: dict,
+    parse_header: bool,
 ) -> snowpark.DataFrame:
     df = reader.csv(path)
     filename = path.strip("/").split("/")[-1]
@@ -120,23 +187,35 @@ def read_data(
             raise Exception("CSV header does not conform to the schema")
         return df
 
-    headers = get_header_names(session, path, snowpark_options)
-
+    headers = get_header_names(
+        session, path, file_format_options, snowpark_read_options
+    )
+
+    df_schema_fields = df.schema.fields
+    if len(headers) == len(df_schema_fields) and parse_header:
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
+        )
    # Handle mismatch in column count between header and data
-    if (
-        len(df.schema.fields) == 1
-        and df.schema.fields[0].name.upper() == "C1"
-        and snowpark_options.get("PARSE_HEADER") is True
-        and len(headers) != len(df.schema.fields)
+    elif (
+        len(df_schema_fields) == 1
+        and df_schema_fields[0].name.upper() == "C1"
+        and parse_header
+        and len(headers) != len(df_schema_fields)
     ):
-        df = (
-            session.read.options(snowpark_options)
-            .schema(StructType([StructField(h, StringType(), True) for h in headers]))
-            .csv(path)
+        df = reader.schema(
+            StructType([StructField(h, StringType(), True) for h in headers])
+        ).csv(path)
+    elif not parse_header and len(headers) != len(df_schema_fields):
+        return df.select([df_schema_fields[i].name for i in range(len(headers))])
+    elif parse_header and len(headers) != len(df_schema_fields):
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
        )
-    elif snowpark_options.get("PARSE_HEADER") is False and len(headers) != len(
-        df.schema.fields
-    ):
-        return df.select([df.schema.fields[i].name for i in range(len(headers))])
-
    return df
snowflake/snowpark_connect/relation/read/map_read_json.py

@@ -2,9 +2,12 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 
+import concurrent.futures
 import copy
 import json
+import os
 import typing
+import uuid
 from contextlib import suppress
 from datetime import datetime
 
@@ -253,20 +256,20 @@ def merge_row_schema(
     return schema
 
 
-def union_data_into_df(
-    result_df: snowpark.DataFrame,
-    data: typing.List[Row],
-    schema: StructType,
+def insert_data_chunk(
     session: snowpark.Session,
-) -> snowpark.DataFrame:
-    current_df = session.create_dataframe(
+    data: list[Row],
+    schema: StructType,
+    table_name: str,
+) -> None:
+    df = session.create_dataframe(
         data=data,
         schema=schema,
     )
-    if result_df is None:
-        return current_df
 
-    return result_df.union(current_df)
+    df.write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=True
+    )
 
 
 def construct_dataframe_by_schema(
@@ -276,39 +279,47 @@ def construct_dataframe_by_schema(
     snowpark_options: dict,
     batch_size: int = 1000,
 ) -> snowpark.DataFrame:
-    result = None
+    table_name = "__sas_json_read_temp_" + uuid.uuid4().hex
+
+    # We can have more workers than CPU count, this is an IO-intensive task
+    max_workers = min(16, os.cpu_count() * 2)
 
     current_data = []
     progress = 0
-    for row in rows:
-        current_data.append(construct_row_by_schema(row, schema, snowpark_options))
-        if len(current_data) >= batch_size:
+
+    # Initialize the temp table
+    session.create_dataframe([], schema=schema).write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=False
+    )
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exc:
+        for row in rows:
+            current_data.append(construct_row_by_schema(row, schema, snowpark_options))
+            if len(current_data) >= batch_size:
+                progress += len(current_data)
+                exc.submit(
+                    insert_data_chunk,
+                    session,
+                    copy.deepcopy(current_data),
+                    schema,
+                    table_name,
+                )
+
+                logger.info(f"JSON reader: finished processing {progress} rows")
+                current_data.clear()
+
+        if len(current_data) > 0:
             progress += len(current_data)
-            result = union_data_into_df(
-                result,
-                current_data,
-                schema,
+            exc.submit(
+                insert_data_chunk,
                 session,
+                copy.deepcopy(current_data),
+                schema,
+                table_name,
             )
-
             logger.info(f"JSON reader: finished processing {progress} rows")
-            current_data = []
-
-    if len(current_data) > 0:
-        progress += len(current_data)
-        result = union_data_into_df(
-            result,
-            current_data,
-            schema,
-            session,
-        )
-
-        logger.info(f"JSON reader: finished processing {progress} rows")
-        current_data = []
 
-    if result is None:
-        raise ValueError("Dataframe cannot be empty")
-    return result
+    return session.table(table_name)
 
 
 def construct_row_by_schema(
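
Instead of unioning batch after batch into one ever-growing DataFrame, the JSON reader now appends each batch to a session-scoped temporary table from a thread pool and returns session.table(...) at the end. Stripped of the Snowpark specifics, the batching pattern looks like this (a self-contained sketch, not the package's code):

    import concurrent.futures
    import copy

    def insert_chunk(chunk: list[int]) -> None:
        # Stand-in for insert_data_chunk: append one batch to the temp table.
        print(f"appending {len(chunk)} rows")

    rows = range(2500)
    batch_size, current = 1000, []
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        for row in rows:
            current.append(row)
            if len(current) >= batch_size:
                # deepcopy so the worker keeps its own snapshot while we clear()
                pool.submit(insert_chunk, copy.deepcopy(current))
                current.clear()
        if current:
            pool.submit(insert_chunk, copy.deepcopy(current))
    # Leaving the with-block waits for all submitted batches to finish.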
snowflake/snowpark_connect/relation/read/map_read_table.py

@@ -11,11 +11,17 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNameMap,
+    make_column_names_snowpark_compatible,
+)
 from snowflake.snowpark_connect.config import auto_uppercase_non_column_identifiers
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.utils import (
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.utils.context import get_processed_views
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
@@ -23,6 +29,7 @@ from snowflake.snowpark_connect.utils.session import _get_current_snowpark_sessi
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
+from snowflake.snowpark_connect.utils.temporary_view_cache import get_temp_view
 
 
 def post_process_df(
@@ -64,15 +71,66 @@ def post_process_df(
         raise
 
 
+def _get_temporary_view(
+    temp_view: DataFrameContainer, table_name: str, plan_id: int
+) -> DataFrameContainer:
+    fields_names = [field.name for field in temp_view.dataframe.schema.fields]
+    fields_types = [field.datatype for field in temp_view.dataframe.schema.fields]
+
+    snowpark_column_names = make_column_names_snowpark_compatible(fields_names, plan_id)
+    # Rename columns in dataframe to prevent conflicting names during joins
+    renamed_df = temp_view.dataframe.select(
+        *(
+            temp_view.dataframe.col(orig).alias(alias)
+            for orig, alias in zip(fields_names, snowpark_column_names)
+        )
+    )
+
+    new_column_map = ColumnNameMap(
+        spark_column_names=temp_view.column_map.get_spark_columns(),
+        snowpark_column_names=snowpark_column_names,
+        column_metadata=temp_view.column_map.column_metadata,
+        column_qualifiers=[split_fully_qualified_spark_name(table_name)]
+        * len(temp_view.column_map.get_spark_columns()),
+        parent_column_name_map=temp_view.column_map.get_parent_column_name_map(),
+    )
+
+    schema = StructType(
+        [
+            StructField(name, type, _is_column=False)
+            for name, type in zip(snowpark_column_names, fields_types)
+        ]
+    )
+    return DataFrameContainer(
+        dataframe=renamed_df,
+        column_map=new_column_map,
+        table_name=temp_view.table_name,
+        alias=temp_view.alias,
+        partition_hint=temp_view.partition_hint,
+        cached_schema_getter=lambda: schema,
+    )
+
+
 def get_table_from_name(
     table_name: str, session: snowpark.Session, plan_id: int
 ) -> DataFrameContainer:
     """Get table from name returning a container."""
+
+    # Verify if recursive view read is not attempted
+    if table_name in get_processed_views():
+        raise AnalysisException(
+            f"[RECURSIVE_VIEW] Recursive view `{table_name}` detected (cycle: `{table_name}` -> `{table_name}`)"
+        )
+
     snowpark_name = ".".join(
         quote_name_without_upper_casing(part)
         for part in split_fully_qualified_spark_name(table_name)
     )
 
+    temp_view = get_temp_view(snowpark_name)
+    if temp_view:
+        return _get_temporary_view(temp_view, table_name, plan_id)
+
     if auto_uppercase_non_column_identifiers():
         snowpark_name = snowpark_name.upper()