snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +12 -3
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_sql.py +112 -53
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
- snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +95 -14
- snowflake/snowpark_connect/server.py +18 -13
- snowflake/snowpark_connect/utils/context.py +21 -14
- snowflake/snowpark_connect/utils/identifiers.py +8 -2
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +3 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_sql.py +112 -53

@@ -61,6 +61,7 @@ from snowflake.snowpark_connect.utils.context import (
     get_session_id,
     get_sql_plan,
     push_evaluating_sql_scope,
+    push_processed_view,
     push_sql_scope,
     set_plan_id_map,
     set_sql_args,
@@ -80,7 +81,16 @@ from ..expression.map_sql_expression import (
     map_logical_plan_expression,
     sql_parser,
 )
-from ..utils.identifiers import spark_to_sf_single_id
+from ..utils.identifiers import (
+    spark_to_sf_single_id,
+    spark_to_sf_single_id_with_unquoting,
+)
+from ..utils.temporary_view_cache import (
+    get_temp_view,
+    register_temp_view,
+    unregister_temp_view,
+)
+from .catalogs import SNOWFLAKE_CATALOG

 _ctes = ContextVar[dict[str, relation_proto.Relation]]("_ctes", default={})
 _cte_definitions = ContextVar[dict[str, any]]("_cte_definitions", default={})
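The three `temporary_view_cache` imports above come from a brand-new module (+61 lines) whose body is not shown in this diff. As a rough orientation only, a minimal in-memory registry with this surface could look like the sketch below; the dict and every other implementation detail are assumptions, not the shipped code. The boolean return of `unregister_temp_view` is inferred from the DropView handler further down, which falls back to a real DROP VIEW when it returns False.

    # Hypothetical sketch of utils/temporary_view_cache.py; only the three
    # imported names are known from this diff, everything else is assumed.
    _temp_views: dict[str, object] = {}  # view name -> DataFrameContainer

    def register_temp_view(name: str, container: object, replace: bool) -> None:
        # map_sql.py pre-checks existence for CREATE TEMPORARY VIEW, so a
        # guarded overwrite is all the cache itself needs to do.
        if not replace and name in _temp_views:
            raise ValueError(f"temporary view {name} already exists")
        _temp_views[name] = container

    def get_temp_view(name: str):
        # None means "not a registered temporary view".
        return _temp_views.get(name)

    def unregister_temp_view(name: str) -> bool:
        # True if a view was actually removed; DropView uses this to decide
        # whether to issue a real DROP VIEW statement instead.
        return _temp_views.pop(name, None) is not None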
@@ -403,6 +413,7 @@ def map_sql_to_pandas_df(
     ) == "UnresolvedHint":
         logical_plan = logical_plan.child()

+    # TODO: Add support for temporary views for SQL cases such as ShowViews, ShowColumns, etc. (currently these cases are not compatible with Spark, returning raw Snowflake rows)
     match class_name:
         case "AddColumns":
             # Handle ALTER TABLE ... ADD COLUMNS (col_name data_type) -> ADD COLUMN col_name data_type
@@ -577,6 +588,23 @@ def map_sql_to_pandas_df(
             )
             snowflake_sql = parsed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
+            spark_view_name = next(
+                sqlglot.parse_one(sql_string, dialect="spark").find_all(
+                    sqlglot.exp.Table
+                )
+            ).name
+            snowflake_view_name = spark_to_sf_single_id_with_unquoting(
+                spark_view_name
+            )
+            temp_view = get_temp_view(snowflake_view_name)
+            if temp_view is not None and not logical_plan.replace():
+                raise AnalysisException(
+                    f"[TEMP_TABLE_OR_VIEW_ALREADY_EXISTS] Cannot create the temporary view `{spark_view_name}` because it already exists."
+                )
+            else:
+                unregister_temp_view(
+                    spark_to_sf_single_id_with_unquoting(spark_view_name)
+                )
         case "CreateView":
             current_schema = session.connection.schema
             if (
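The added branch recovers the view name from the original Spark SQL text with sqlglot rather than from the logical plan. The extraction step in isolation (the SQL string here is a made-up example):

    import sqlglot

    sql_string = "CREATE TEMPORARY VIEW my_view AS SELECT 1 AS id"
    # find_all() walks the parse tree from the root, so for CREATE ... VIEW the
    # view's own name is the first Table expression found, ahead of any source
    # tables inside the SELECT.
    spark_view_name = next(
        sqlglot.parse_one(sql_string, dialect="spark").find_all(sqlglot.exp.Table)
    ).name
    print(spark_view_name)  # my_view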
@@ -613,50 +641,60 @@ def map_sql_to_pandas_df(
                 else None,
             )
         case "CreateViewCommand":
-
-
-
-
-
-
-
-
+            with push_processed_view(logical_plan.name().identifier()):
+                df_container = execute_logical_plan(logical_plan.plan())
+                df = df_container.dataframe
+                user_specified_spark_column_names = [
+                    str(col._1())
+                    for col in as_java_list(logical_plan.userSpecifiedColumns())
+                ]
+                df_container = DataFrameContainer.create_with_column_mapping(
+                    dataframe=df,
+                    spark_column_names=user_specified_spark_column_names
+                    if user_specified_spark_column_names
+                    else df_container.column_map.get_spark_columns(),
+                    snowpark_column_names=df_container.column_map.get_snowpark_columns(),
+                    parent_column_name_map=df_container.column_map,
                 )
-
-
-
-
-
-                    logical_plan.viewType(),
-                    jpype.JClass(
-                        "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
-                    ),
-                ):
-                    name = f"{global_config.spark_sql_globalTempDatabase}.{name}"
-                comment = logical_plan.comment()
-                maybe_comment = (
-                    _escape_sql_comment(str(comment.get()))
-                    if comment.isDefined()
-                    else None
-                )
-
-                df = _rename_columns(
-                    df, logical_plan.userSpecifiedColumns(), df_container.column_map
-                )
-
-                if logical_plan.replace():
-                    df.create_or_replace_temp_view(
-                        name,
-                        comment=maybe_comment,
+                is_global = isinstance(
+                    logical_plan.viewType(),
+                    jpype.JClass(
+                        "org.apache.spark.sql.catalyst.analysis.GlobalTempView$"
+                    ),
                 )
-
-
-
-
+                if is_global:
+                    view_name = [
+                        global_config.spark_sql_globalTempDatabase,
+                        logical_plan.name().quotedString(),
+                    ]
+                else:
+                    view_name = [logical_plan.name().quotedString()]
+                view_name = [
+                    spark_to_sf_single_id_with_unquoting(part) for part in view_name
+                ]
+                joined_view_name = ".".join(view_name)
+
+                register_temp_view(
+                    joined_view_name,
+                    df_container,
+                    logical_plan.replace(),
+                )
+                tmp_views = _get_current_temp_objects()
+                tmp_views.add(
+                    (
+                        CURRENT_CATALOG_NAME,
+                        session.connection.schema,
+                        str(logical_plan.name().identifier()),
+                    )
                 )
         case "DescribeColumn":
-            name = get_relation_identifier_name(logical_plan.column())
+            name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.column()
+            )
+            if get_temp_view(name):
+                return SNOWFLAKE_CATALOG.listColumns(unquote_if_quoted(name)), ""
             # todo double check if this is correct
+            name = get_relation_identifier_name(logical_plan.column())
             rows = session.sql(f"DESCRIBE TABLE {name}").collect()
         case "DescribeNamespace":
             name = get_relation_identifier_name(logical_plan.namespace(), True)
@@ -731,9 +769,13 @@ def map_sql_to_pandas_df(
             if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
             session.sql(f"DROP TABLE {if_exists}{name}").collect()
         case "DropView":
-
-
-
+            temporary_view_name = get_relation_identifier_name_without_uppercasing(
+                logical_plan.child()
+            )
+            if not unregister_temp_view(temporary_view_name):
+                name = get_relation_identifier_name(logical_plan.child())
+                if_exists = "IF EXISTS " if logical_plan.ifExists() else ""
+                session.sql(f"DROP VIEW {if_exists}{name}").collect()
         case "ExplainCommand":
             inner_plan = logical_plan.logicalPlan()
             logical_plan_name = inner_plan.nodeName()
@@ -2173,21 +2215,38 @@ def map_logical_plan_relation(
     return proto


-def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
+def _get_relation_identifier(name_obj) -> str:
+    # IDENTIFIER(<table_name>), or IDENTIFIER(<method name>)
+    expr_proto = map_logical_plan_expression(name_obj.identifierExpr())
+    session = snowpark.Session.get_active_session()
+    m = ColumnNameMap([], [], None)
+    expr = map_single_column_expression(
+        expr_proto, m, ExpressionTyper.dummy_typer(session)
+    )
+    return spark_to_sf_single_id(session.range(1).select(expr[1].col).collect()[0][0])
+
+
+def get_relation_identifier_name_without_uppercasing(name_obj) -> str:
     if name_obj.getClass().getSimpleName() in (
         "PlanWithUnresolvedIdentifier",
         "ExpressionWithUnresolvedIdentifier",
     ):
-
-
-
-
-
-            expr_proto, m, ExpressionTyper.dummy_typer(session)
-        )
-        name = spark_to_sf_single_id(
-            session.range(1).select(expr[1].col).collect()[0][0]
+        return _get_relation_identifier(name_obj)
+    else:
+        name = ".".join(
+            quote_name_without_upper_casing(str(part))
+            for part in as_java_list(name_obj.nameParts())
         )
+
+        return name
+
+
+def get_relation_identifier_name(name_obj, is_multi_part: bool = False) -> str:
+    if name_obj.getClass().getSimpleName() in (
+        "PlanWithUnresolvedIdentifier",
+        "ExpressionWithUnresolvedIdentifier",
+    ):
+        return _get_relation_identifier(name_obj)
     else:
         if is_multi_part:
             try:
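The new `_get_relation_identifier` helper resolves `IDENTIFIER(<expr>)` by compiling the expression and evaluating it against a one-row dataframe. The evaluation trick on its own, with a literal standing in for the mapped expression (a sketch, not the helper's real inputs):

    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import lit, upper

    def eval_scalar(session: Session, col):
        # session.range(1) builds a one-row frame, so selecting any
        # row-independent expression over it collects to exactly one value.
        return session.range(1).select(col).collect()[0][0]

    # e.g. eval_scalar(session, upper(lit("my_table"))) -> "MY_TABLE"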
snowflake/snowpark_connect/relation/read/map_read.py +22 -3

@@ -46,6 +46,9 @@ def map_read(

     Currently, the supported read formats are `csv`, `json` and `parquet`.
     """
+
+    materialize_df = True
+
     match rel.read.WhichOneof("read_type"):
         case "named_table":
             return map_read_table_or_file(rel)
@@ -99,6 +102,10 @@ def map_read(
                 for path in rel.read.data_source.paths
             ]

+            # JSON already materializes the table internally
+            if read_format == "json":
+                materialize_df = False
+
             result = _read_file(
                 clean_source_paths, options, read_format, rel, schema, session
             )
@@ -159,7 +166,9 @@
             raise SnowparkConnectNotImplementedError(f"Unsupported read type: {other}")

     return df_cache_map_put_if_absent(
-        (get_session_id(), rel.common.plan_id),
+        (get_session_id(), rel.common.plan_id),
+        lambda: result,
+        materialize=materialize_df,
     )

@@ -205,6 +214,15 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
     return None


+def _quote_stage_path(stage_path: str) -> str:
+    """
+    Quote stage paths to escape any special characters.
+    """
+    if stage_path.startswith("@"):
+        return f"'{stage_path}'"
+    return stage_path
+
+
 def _read_file(
     clean_source_paths: list[str],
     options: dict,
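`_quote_stage_path` only wraps stage URIs (paths starting with `@`); local paths that have not been uploaded yet pass through untouched. Copied logic with illustrative inputs:

    def _quote_stage_path(stage_path: str) -> str:
        if stage_path.startswith("@"):
            return f"'{stage_path}'"
        return stage_path

    assert _quote_stage_path("@my_stage/dir with space/f.csv") == "'@my_stage/dir with space/f.csv'"
    assert _quote_stage_path("/tmp/local.csv") == "/tmp/local.csv"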
@@ -218,6 +236,7 @@ def _read_file(
         session,
     )
     upload_files_if_needed(paths, clean_source_paths, session, read_format)
+    paths = [_quote_stage_path(path) for path in paths]
     match read_format:
         case "csv":
             from snowflake.snowpark_connect.relation.read.map_read_csv import (
@@ -285,8 +304,8 @@ def upload_files_if_needed(

 def _upload_dir(target: str, source: str) -> None:
     # overwrite=True will not remove all stale files in the target prefix
-
-    remove_command = f"REMOVE {target}/"
+    # Quote the target path to allow special characters.
+    remove_command = f"REMOVE '{target}/'"
     assert (
         "//" not in remove_command
     ), f"Remove command {remove_command} contains double slash"
snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26

@@ -3,6 +3,7 @@
 #

 import copy
+from typing import Any

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -16,6 +17,7 @@ from snowflake.snowpark_connect.relation.read.utils import (
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.utils.io_utils import cached_file_format
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
@@ -42,21 +44,34 @@ def map_read_csv(
         )
     else:
         snowpark_options = options.convert_to_snowpark_args()
+        parse_header = snowpark_options.get("PARSE_HEADER", False)
+        file_format_options = _parse_csv_snowpark_options(snowpark_options)
+        file_format = cached_file_format(session, "csv", file_format_options)
+
+        snowpark_read_options = dict()
+        snowpark_read_options["FORMAT_NAME"] = file_format
+        snowpark_read_options["ENFORCE_EXISTING_FILE_FORMAT"] = True
+        snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
+            "INFER_SCHEMA", False
+        )
+        snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)
+
         raw_options = rel.read.data_source.options
         if schema is None or (
-
-            and raw_options.get("enforceSchema", "True").lower() == "false"
+            parse_header and raw_options.get("enforceSchema", "True").lower() == "false"
         ): # Schema has to equals to header's format
-            reader = session.read.options(
+            reader = session.read.options(snowpark_read_options)
         else:
-            reader = session.read.options(
+            reader = session.read.options(snowpark_read_options).schema(schema)
         df = read_data(
             reader,
             schema,
             session,
             paths[0],
-
+            file_format_options,
+            snowpark_read_options,
             raw_options,
+            parse_header,
         )
         if len(paths) > 1:
             # TODO: figure out if this is what Spark does.
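`cached_file_format` comes from the new `utils/io_utils.py` (+36 lines), whose body is not in this diff. Judging by its call sites, it creates (or reuses) a named Snowflake file format for a given options dict and returns the name, which the reader then passes as FORMAT_NAME together with ENFORCE_EXISTING_FILE_FORMAT. A speculative sketch of that contract; the naming scheme and generated SQL are assumptions:

    import hashlib
    import json

    # Hypothetical shape of utils/io_utils.cached_file_format: derive a stable
    # name from the options, create the format once, and return the name.
    def cached_file_format(session, file_type: str, options: dict) -> str:
        digest = hashlib.sha1(
            json.dumps({"type": file_type, **options}, sort_keys=True, default=str).encode()
        ).hexdigest()[:16]
        name = f"SNOWPARK_CONNECT_{file_type.upper()}_FORMAT_{digest}"
        rendered = " ".join(f"{k} = {v!r}" for k, v in options.items())
        # CREATE FILE FORMAT IF NOT EXISTS makes repeated reads with the same
        # options reuse one server-side object instead of recreating it.
        session.sql(
            f"CREATE FILE FORMAT IF NOT EXISTS {name} TYPE = {file_type.upper()} {rendered}"
        ).collect()
        return name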
@@ -81,15 +96,65 @@ def map_read_csv(
     )


+_csv_file_format_allowed_options = {
+    "COMPRESSION",
+    "RECORD_DELIMITER",
+    "FIELD_DELIMITER",
+    "MULTI_LINE",
+    "FILE_EXTENSION",
+    "PARSE_HEADER",
+    "SKIP_HEADER",
+    "SKIP_BLANK_LINES",
+    "DATE_FORMAT",
+    "TIME_FORMAT",
+    "TIMESTAMP_FORMAT",
+    "BINARY_FORMAT",
+    "ESCAPE",
+    "ESCAPE_UNENCLOSED_FIELD",
+    "TRIM_SPACE",
+    "FIELD_OPTIONALLY_ENCLOSED_BY",
+    "NULL_IF",
+    "ERROR_ON_COLUMN_COUNT_MISMATCH",
+    "REPLACE_INVALID_CHARACTERS",
+    "EMPTY_FIELD_AS_NULL",
+    "SKIP_BYTE_ORDER_MARK",
+    "ENCODING",
+}
+
+
+def _parse_csv_snowpark_options(snowpark_options: dict[str, Any]) -> dict[str, Any]:
+    file_format_options = dict()
+    for key, value in snowpark_options.items():
+        upper_key = key.upper()
+        if upper_key in _csv_file_format_allowed_options:
+            file_format_options[upper_key] = value
+
+    # This option has to be removed, because we cannot use a predefined file format and the parse_header option at the same time.
+    # Such a combination causes Snowpark to raise SQL compilation error: Invalid file format "PARSE_HEADER" is only allowed for CSV INFER_SCHEMA and MATCH_BY_COLUMN_NAME
+    parse_header = file_format_options.get("PARSE_HEADER", False)
+    if parse_header:
+        file_format_options["SKIP_HEADER"] = 1
+        del file_format_options["PARSE_HEADER"]
+
+    return file_format_options
+
+
 def get_header_names(
     session: snowpark.Session,
     path: list[str],
-
+    file_format_options: dict,
+    snowpark_read_options: dict,
 ) -> list[str]:
-
-
+    no_header_file_format_options = copy.copy(file_format_options)
+    no_header_file_format_options["PARSE_HEADER"] = False
+    no_header_file_format_options.pop("SKIP_HEADER", None)
+
+    file_format = cached_file_format(session, "csv", no_header_file_format_options)
+    no_header_snowpark_read_options = copy.copy(snowpark_read_options)
+    no_header_snowpark_read_options["FORMAT_NAME"] = file_format
+    no_header_snowpark_read_options.pop("INFER_SCHEMA", None)

-    header_df = session.read.options(
+    header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
     header_data = header_df.collect()[0]
     return [
         f'"{header_data[i]}"'
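The net effect of `_parse_csv_snowpark_options` on a typical Spark-derived options dict: keys are upper-cased, anything outside the allow-list is dropped, and PARSE_HEADER is rewritten to SKIP_HEADER = 1, since a predefined file format cannot carry PARSE_HEADER. A self-contained restatement over a trimmed allow-list:

    # Illustrative restatement; the real allow-list is the full set above.
    _allowed = {"FIELD_DELIMITER", "SKIP_HEADER", "PARSE_HEADER"}

    def parse_csv_options(options: dict) -> dict:
        out = {k.upper(): v for k, v in options.items() if k.upper() in _allowed}
        if out.get("PARSE_HEADER", False):
            out["SKIP_HEADER"] = 1      # header consumed by skipping the first row
            del out["PARSE_HEADER"]     # invalid on a predefined file format
        return out

    assert parse_csv_options({"field_delimiter": ";", "PARSE_HEADER": True, "x": 1}) == {
        "FIELD_DELIMITER": ";",
        "SKIP_HEADER": 1,
    }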
@@ -103,8 +168,10 @@ def read_data(
     schema: snowpark.types.StructType | None,
     session: snowpark.Session,
     path: list[str],
-
+    file_format_options: dict,
+    snowpark_read_options: dict,
     raw_options: dict,
+    parse_header: bool,
 ) -> snowpark.DataFrame:
     df = reader.csv(path)
     filename = path.strip("/").split("/")[-1]
@@ -120,23 +187,35 @@
             raise Exception("CSV header does not conform to the schema")
         return df

-    headers = get_header_names(
-
+    headers = get_header_names(
+        session, path, file_format_options, snowpark_read_options
+    )
+
+    df_schema_fields = df.schema.fields
+    if len(headers) == len(df_schema_fields) and parse_header:
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
+        )
     # Handle mismatch in column count between header and data
-
-    len(
-        and
-        and
-        and len(headers) != len(
+    elif (
+        len(df_schema_fields) == 1
+        and df_schema_fields[0].name.upper() == "C1"
+        and parse_header
+        and len(headers) != len(df_schema_fields)
     ):
-        df = (
-
-
-
+        df = reader.schema(
+            StructType([StructField(h, StringType(), True) for h in headers])
+        ).csv(path)
+    elif not parse_header and len(headers) != len(df_schema_fields):
+        return df.select([df_schema_fields[i].name for i in range(len(headers))])
+    elif parse_header and len(headers) != len(df_schema_fields):
+        return df.select(
+            [
+                snowpark_fn.col(df_schema_fields[i].name).alias(headers[i])
+                for i in range(len(headers))
+            ]
         )
-    elif snowpark_options.get("PARSE_HEADER") is False and len(headers) != len(
-        df.schema.fields
-    ):
-        return df.select([df.schema.fields[i].name for i in range(len(headers))])
-
     return df
snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34

@@ -2,9 +2,12 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #

+import concurrent.futures
 import copy
 import json
+import os
 import typing
+import uuid
 from contextlib import suppress
 from datetime import datetime
@@ -253,20 +256,20 @@ def merge_row_schema(
     return schema


-def union_data_into_df(
-    result_df: snowpark.DataFrame,
-    data: typing.List[Row],
-    schema: StructType,
+def insert_data_chunk(
     session: snowpark.Session,
-
-
+    data: list[Row],
+    schema: StructType,
+    table_name: str,
+) -> None:
+    df = session.create_dataframe(
         data=data,
         schema=schema,
     )
-    if result_df is None:
-        return current_df

-
+    df.write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=True
+    )


 def construct_dataframe_by_schema(
@@ -276,39 +279,47 @@ def construct_dataframe_by_schema(
     snowpark_options: dict,
     batch_size: int = 1000,
 ) -> snowpark.DataFrame:
-
+    table_name = "__sas_json_read_temp_" + uuid.uuid4().hex
+
+    # We can have more workers than the CPU count; this is an IO-intensive task
+    max_workers = min(16, os.cpu_count() * 2)

     current_data = []
     progress = 0
-
-
-
+
+    # Initialize the temp table
+    session.create_dataframe([], schema=schema).write.mode("append").save_as_table(
+        table_name, table_type="temp", table_exists=False
+    )
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exc:
+        for row in rows:
+            current_data.append(construct_row_by_schema(row, schema, snowpark_options))
+            if len(current_data) >= batch_size:
+                progress += len(current_data)
+                exc.submit(
+                    insert_data_chunk,
+                    session,
+                    copy.deepcopy(current_data),
+                    schema,
+                    table_name,
+                )
+
+                logger.info(f"JSON reader: finished processing {progress} rows")
+                current_data.clear()
+
+        if len(current_data) > 0:
             progress += len(current_data)
-
-
-                current_data,
-                schema,
+            exc.submit(
+                insert_data_chunk,
                 session,
+                copy.deepcopy(current_data),
+                schema,
+                table_name,
             )
-
             logger.info(f"JSON reader: finished processing {progress} rows")
-            current_data = []
-
-    if len(current_data) > 0:
-        progress += len(current_data)
-        result = union_data_into_df(
-            result,
-            current_data,
-            schema,
-            session,
-        )
-
-        logger.info(f"JSON reader: finished processing {progress} rows")
-        current_data = []

-
-        raise ValueError("Dataframe cannot be empty")
-    return result
+    return session.table(table_name)


 def construct_row_by_schema(
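`construct_dataframe_by_schema` now appends batches to a session-scoped temp table from a thread pool instead of unioning dataframes, which built an ever-deeper plan. The control flow in isolation, with a thread-safe stub standing in for the Snowpark append; batch size and data are illustrative:

    import concurrent.futures
    import copy
    import os
    import threading

    written, lock = [], threading.Lock()

    def insert_chunk(rows: list) -> None:
        # Stands in for insert_data_chunk's save_as_table(..., table_exists=True).
        with lock:
            written.extend(rows)

    batch_size = 3
    max_workers = min(16, (os.cpu_count() or 1) * 2)  # IO-bound: oversubscribe CPUs

    current: list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as exc:
        for row in range(10):
            current.append(row)
            if len(current) >= batch_size:
                # The copy matters: `current` is cleared and refilled while
                # worker threads may still be reading their chunk.
                exc.submit(insert_chunk, copy.deepcopy(current))
                current.clear()
        if current:
            exc.submit(insert_chunk, copy.deepcopy(current))
    # Leaving the with-block joins all pending inserts before the table is read.
    assert sorted(written) == list(range(10))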
snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0

@@ -11,11 +11,17 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     unquote_if_quoted,
 )
 from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.types import StructField, StructType
+from snowflake.snowpark_connect.column_name_handler import (
+    ColumnNameMap,
+    make_column_names_snowpark_compatible,
+)
 from snowflake.snowpark_connect.config import auto_uppercase_non_column_identifiers
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.utils import (
     rename_columns_as_snowflake_standard,
 )
+from snowflake.snowpark_connect.utils.context import get_processed_views
 from snowflake.snowpark_connect.utils.identifiers import (
     split_fully_qualified_spark_name,
 )
@@ -23,6 +29,7 @@ from snowflake.snowpark_connect.utils.session import _get_current_snowpark_sessi
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
+from snowflake.snowpark_connect.utils.temporary_view_cache import get_temp_view


 def post_process_df(
@@ -64,15 +71,66 @@ def post_process_df(
         raise


+def _get_temporary_view(
+    temp_view: DataFrameContainer, table_name: str, plan_id: int
+) -> DataFrameContainer:
+    fields_names = [field.name for field in temp_view.dataframe.schema.fields]
+    fields_types = [field.datatype for field in temp_view.dataframe.schema.fields]
+
+    snowpark_column_names = make_column_names_snowpark_compatible(fields_names, plan_id)
+    # Rename columns in dataframe to prevent conflicting names during joins
+    renamed_df = temp_view.dataframe.select(
+        *(
+            temp_view.dataframe.col(orig).alias(alias)
+            for orig, alias in zip(fields_names, snowpark_column_names)
+        )
+    )
+
+    new_column_map = ColumnNameMap(
+        spark_column_names=temp_view.column_map.get_spark_columns(),
+        snowpark_column_names=snowpark_column_names,
+        column_metadata=temp_view.column_map.column_metadata,
+        column_qualifiers=[split_fully_qualified_spark_name(table_name)]
+        * len(temp_view.column_map.get_spark_columns()),
+        parent_column_name_map=temp_view.column_map.get_parent_column_name_map(),
+    )
+
+    schema = StructType(
+        [
+            StructField(name, type, _is_column=False)
+            for name, type in zip(snowpark_column_names, fields_types)
+        ]
+    )
+    return DataFrameContainer(
+        dataframe=renamed_df,
+        column_map=new_column_map,
+        table_name=temp_view.table_name,
+        alias=temp_view.alias,
+        partition_hint=temp_view.partition_hint,
+        cached_schema_getter=lambda: schema,
+    )
+
+
 def get_table_from_name(
     table_name: str, session: snowpark.Session, plan_id: int
 ) -> DataFrameContainer:
     """Get table from name returning a container."""
+
+    # Verify that a recursive view read is not being attempted
+    if table_name in get_processed_views():
+        raise AnalysisException(
+            f"[RECURSIVE_VIEW] Recursive view `{table_name}` detected (cycle: `{table_name}` -> `{table_name}`)"
+        )
+
     snowpark_name = ".".join(
         quote_name_without_upper_casing(part)
         for part in split_fully_qualified_spark_name(table_name)
     )
+
+    temp_view = get_temp_view(snowpark_name)
+    if temp_view:
+        return _get_temporary_view(temp_view, table_name, plan_id)
+
     if auto_uppercase_non_column_identifiers():
         snowpark_name = snowpark_name.upper()

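`get_processed_views` pairs with the `push_processed_view` context manager used around CreateViewCommand in map_sql.py: while a view's defining plan is being executed, its name is on a stack, so a read of the same name inside that plan is flagged as recursive instead of looping forever. A minimal sketch of such a guard with a ContextVar; all internals here are assumptions about utils/context.py:

    import contextlib
    from contextvars import ContextVar

    _processed_views: ContextVar[tuple[str, ...]] = ContextVar(
        "_processed_views", default=()
    )

    def get_processed_views() -> tuple[str, ...]:
        return _processed_views.get()

    @contextlib.contextmanager
    def push_processed_view(name: str):
        # In scope only while the view's defining plan is executed.
        token = _processed_views.set(_processed_views.get() + (str(name),))
        try:
            yield
        finally:
            _processed_views.reset(token)

    # While `v` is being defined, get_table_from_name("v") raises:
    # [RECURSIVE_VIEW] Recursive view `v` detected (cycle: `v` -> `v`)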