snowpark-connect 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -3
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_function.py +203 -128
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +102 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +21 -2
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_sql.py +18 -191
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/write/map_write.py +68 -24
- snowflake/snowpark_connect/server.py +9 -0
- snowflake/snowpark_connect/type_mapping.py +4 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/session.py +0 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +40 -29
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/NOTICE-binary +0 -0

snowflake/snowpark_connect/relation/map_aggregate.py +102 -18

@@ -4,10 +4,14 @@

 import re
 from dataclasses import dataclass
+from typing import Optional

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

+import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
+from snowflake.snowpark import Column
+from snowflake.snowpark._internal.analyzer.unary_expression import Alias
 from snowflake.snowpark.types import DataType
 from snowflake.snowpark_connect.column_name_handler import (
     make_column_names_snowpark_compatible,
@@ -21,6 +25,7 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
+    get_is_evaluating_sql,
     set_current_grouping_columns,
     temporary_pivot_expression,
 )
@@ -131,19 +136,57 @@ def map_pivot_aggregate(
         get_literal_field_and_name(lit)[0] for lit in rel.aggregate.pivot.values
     ]

+    used_columns = {pivot_column[1].col._expression.name}
+    if get_is_evaluating_sql():
+        # When evaluating SQL spark doesn't trim columns from the result
+        used_columns = {"*"}
+    else:
+        for expression in rel.aggregate.aggregate_expressions:
+            matched_identifiers = re.findall(
+                r'unparsed_identifier: "(.*)"', expression.__str__()
+            )
+            for identifier in matched_identifiers:
+                mapped_col = input_container.column_map.spark_to_col.get(
+                    identifier, None
+                )
+                if mapped_col:
+                    used_columns.add(mapped_col[0].snowpark_name)
+
     if len(columns.grouping_expressions()) == 0:
-        result =
-
-
+        result = (
+            input_df_actual.select(*used_columns)
+            .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
+            .agg(*columns.aggregation_expressions(unalias=True))
+        )
     else:
         result = (
             input_df_actual.group_by(*columns.grouping_expressions())
             .pivot(pivot_column[1].col, pivot_values if pivot_values else None)
-            .agg(*columns.aggregation_expressions())
+            .agg(*columns.aggregation_expressions(unalias=True))
         )

+    agg_name_list = [c.spark_name for c in columns.grouping_columns]
+
+    # Calculate number of pivot values for proper Spark-compatible indexing
+    total_pivot_columns = len(result.columns) - len(agg_name_list)
+    num_pivot_values = (
+        total_pivot_columns // len(columns.aggregation_columns)
+        if len(columns.aggregation_columns) > 0
+        else 1
+    )
+
+    def _get_agg_exp_alias_for_col(col_index: int) -> Optional[str]:
+        if col_index < len(agg_name_list) or len(columns.aggregation_columns) <= 1:
+            return None
+        else:
+            index = (col_index - len(agg_name_list)) // num_pivot_values
+            return columns.aggregation_columns[index].spark_name
+
     spark_columns = []
-    for col in [
+    for col in [
+        pivot_column_name(c, _get_agg_exp_alias_for_col(i))
+        for i, c in enumerate(result.columns)
+    ]:
         spark_col = (
             input_container.column_map.get_spark_column_name_from_snowpark_column_name(
                 col, allow_non_exists=True
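
Note on the used_columns logic above: before an ungrouped pivot, the input DataFrame is now trimmed to the pivot column plus every column referenced by an aggregate expression (unless the plan comes from SQL, where Spark keeps all columns), and those references are pulled out of the protobuf text form of each aggregate expression. A minimal sketch of that extraction; the proto_text value below is a hypothetical rendering of expression.__str__():

import re

proto_text = '''
unresolved_function {
  function_name: "sum"
  arguments {
    unresolved_attribute {
      unparsed_identifier: "earnings"
    }
  }
}
'''

# Same pattern as in the diff: collect every Spark column name the expression mentions.
print(re.findall(r'unparsed_identifier: "(.*)"', proto_text))  # ['earnings']
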
@@ -153,22 +196,57 @@ def map_pivot_aggregate(
         if spark_col is not None:
             spark_columns.append(spark_col)
         else:
-
+            # Handle NULL column names to match Spark behavior (lowercase 'null')
+            if col == "NULL":
+                spark_columns.append(col.lower())
+            else:
+                spark_columns.append(col)
+
+    grouping_cols_count = len(agg_name_list)
+    pivot_cols = result.columns[grouping_cols_count:]
+    spark_pivot_cols = spark_columns[grouping_cols_count:]
+
+    num_agg_functions = len(columns.aggregation_columns)
+    num_pivot_values = len(pivot_cols) // num_agg_functions
+
+    reordered_snowpark_cols = []
+    reordered_spark_cols = []
+    column_indices = []  # 1-based indexing
+
+    for i in range(grouping_cols_count):
+        reordered_snowpark_cols.append(result.columns[i])
+        reordered_spark_cols.append(spark_columns[i])
+        column_indices.append(i + 1)
+
+    for pivot_idx in range(num_pivot_values):
+        for agg_idx in range(num_agg_functions):
+            current_pos = agg_idx * num_pivot_values + pivot_idx
+            if current_pos < len(pivot_cols):
+                reordered_snowpark_cols.append(pivot_cols[current_pos])
+                reordered_spark_cols.append(spark_pivot_cols[current_pos])
+                original_index = grouping_cols_count + current_pos
+                column_indices.append(original_index + 1)
+
+    reordered_result = result.select(
+        *[snowpark_fn.col(f"${idx}") for idx in column_indices]
+    )

-    agg_name_list = [c.spark_name for c in columns.grouping_columns]
     return DataFrameContainer.create_with_column_mapping(
-        dataframe=
-        spark_column_names=
-        snowpark_column_names=
+        dataframe=reordered_result,
+        spark_column_names=reordered_spark_cols,
+        snowpark_column_names=[f"${idx}" for idx in column_indices],
         column_qualifiers=(
             columns.get_qualifiers()[: len(agg_name_list)]
-            + [[]] * (len(
+            + [[]] * (len(reordered_spark_cols) - len(agg_name_list))
         ),
         parent_column_name_map=input_container.column_map,
+        snowpark_column_types=[
+            result.schema.fields[idx - 1].datatype for idx in column_indices
+        ],
     )


-def
+def pivot_column_name(snowpark_cname, opt_alias: Optional[str] = None) -> Optional[str]:
     # For values that are used as pivoted columns, the input and output are in the following format (outermost double quotes are part of the input):

     # 1. "'Java'" -> Java
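
Note on the reordering block above: it assumes Snowpark emits pivot columns aggregation-major (all pivot values for the first aggregation, then for the second) and rewrites them pivot-value-major, which is the ordering Spark produces when several aggregations are pivoted. A minimal sketch of the index arithmetic on plain lists, with hypothetical column names:

# 2 aggregate functions, 3 pivot values, laid out aggregation-major.
num_agg_functions, num_pivot_values = 2, 3
pivot_cols = ["v1_sum", "v2_sum", "v3_sum", "v1_avg", "v2_avg", "v3_avg"]

reordered = []
for pivot_idx in range(num_pivot_values):
    for agg_idx in range(num_agg_functions):
        # Same index formula as current_pos in the diff.
        reordered.append(pivot_cols[agg_idx * num_pivot_values + pivot_idx])

print(reordered)  # ['v1_sum', 'v1_avg', 'v2_sum', 'v2_avg', 'v3_sum', 'v3_avg']
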
@@ -183,7 +261,7 @@ def string_parser(s):

     try:
         # handling values that are used as pivoted columns
-        match = re.match(r'^"\'(.*)\'"$',
+        match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
         # extract the content between the outermost double quote followed by a single quote "'
         content = match.group(1)
         # convert the escaped double quote to the actual double quote
@@ -195,10 +273,10 @@ def string_parser(s):
         content = re.sub(r"'", "", content)
         # replace the placeholder with the single quote which we want to preserve
         result = content.replace(escape_single_quote_placeholder, "'")
-        return result
+        return f"{result}_{opt_alias}" if opt_alias else result
     except Exception:
         # fallback to the original logic, handling aliased column names
-        double_quote_list = re.findall(r'"(.*?)"',
+        double_quote_list = re.findall(r'"(.*?)"', snowpark_cname)
         spark_string = ""
         for entry in list(filter(None, double_quote_list)):
             if "'" in entry:
@@ -210,7 +288,7 @@ def string_parser(s):
                 spark_string += entry
             else:
                 spark_string += '"' + entry + '"'
-        return
+        return snowpark_cname if spark_string == "" else spark_string


 @dataclass(frozen=True)
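
Note on pivot_column_name above: it now also receives the aggregate's Spark name and appends it as a suffix, matching Spark's <value>_<aggregation> column naming when more than one aggregation is pivoted. A minimal, self-contained sketch of the quote-stripping and suffixing; strip_pivot_quotes is a hypothetical stand-in that skips the escaped-quote and fallback handling of the real function:

import re

def strip_pivot_quotes(snowpark_cname, opt_alias=None):
    # Snowpark returns pivoted value columns quoted like "'Java'".
    match = re.match(r'^"\'(.*)\'"$', snowpark_cname)
    value = match.group(1)
    return f"{value}_{opt_alias}" if opt_alias else value

print(strip_pivot_quotes("\"'Java'\""))           # Java
print(strip_pivot_quotes("\"'Java'\"", "total"))  # Java_total
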
@@ -231,8 +309,14 @@ class _Columns:
     def grouping_expressions(self) -> list[snowpark.Column]:
         return [col.expression for col in self.grouping_columns]

-    def aggregation_expressions(self) -> list[snowpark.Column]:
-
+    def aggregation_expressions(self, unalias: bool = False) -> list[snowpark.Column]:
+        def _unalias(col: snowpark.Column) -> snowpark.Column:
+            if unalias and hasattr(col, "_expr1") and isinstance(col._expr1, Alias):
+                return _unalias(Column(col._expr1.child))
+            else:
+                return col
+
+        return [_unalias(col.expression) for col in self.aggregation_columns]

     def expressions(self) -> list[snowpark.Column]:
         return self.grouping_expressions() + self.aggregation_expressions()
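
Note on aggregation_expressions(unalias=True) above: alias wrappers are stripped before the columns reach Snowpark's pivot().agg(), and the aggregate's Spark name is instead carried into the output column name by pivot_column_name when several aggregations are present. A self-contained sketch of the unwrapping pattern using stand-in classes (the real Column and Alias live in snowflake.snowpark and its _internal package):

from dataclasses import dataclass

@dataclass
class Alias:          # stand-in for the internal Alias expression
    child: object     # the wrapped expression
    name: str

@dataclass
class Column:         # stand-in for snowpark.Column
    _expr1: object

def unalias(col: Column) -> Column:
    # Same shape as the _unalias helper: peel Alias wrappers recursively.
    if isinstance(col._expr1, Alias):
        return unalias(Column(col._expr1.child))
    return col

aliased = Column(Alias(child="sum(earnings)", name="total"))
print(unalias(aliased)._expr1)  # sum(earnings)
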

snowflake/snowpark_connect/relation/map_column_ops.py +21 -2

@@ -666,10 +666,29 @@ def map_with_columns_renamed(
     )

     # Validate for naming conflicts
-
+    rename_map = dict(rel.with_columns_renamed.rename_columns_map)
+    new_names_list = list(rename_map.values())
     seen = set()
     for new_name in new_names_list:
-        if
+        # Check if this new name conflicts with existing columns
+        # But allow renaming a column to a different case version of itself
+        is_case_insensitive_self_rename = False
+        if not global_config.spark_sql_caseSensitive:
+            # Find the source column(s) that map to this new name
+            source_columns = [
+                old_name
+                for old_name, new_name_candidate in rename_map.items()
+                if new_name_candidate == new_name
+            ]
+            # Check if any source column is the same as new name when case-insensitive
+            is_case_insensitive_self_rename = any(
+                source_col.lower() == new_name.lower() for source_col in source_columns
+            )
+
+        if (
+            column_map.has_spark_column(new_name)
+            and not is_case_insensitive_self_rename
+        ):
             # Spark doesn't allow reusing existing names, even if the result df will not contain duplicate columns
             raise _column_exists_error(new_name)
         if (global_config.spark_sql_caseSensitive and new_name in seen) or (
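
Note on the conflict check above: a new name is rejected when it collides with an existing column, unless the rename is only a case change of that same column and spark.sql.caseSensitive is false. A minimal sketch of the rule on plain dicts and sets, with hypothetical column names; the membership test below stands in for column_map.has_spark_column and is an assumption:

existing_columns = {"ID", "name"}
rename_map = {"ID": "id", "name": "ID"}
case_sensitive = False

for new_name in rename_map.values():
    clashes_existing = any(
        c == new_name if case_sensitive else c.lower() == new_name.lower()
        for c in existing_columns
    )
    self_rename = not case_sensitive and any(
        old.lower() == new_name.lower()
        for old, candidate in rename_map.items()
        if candidate == new_name
    )
    print(new_name, "rejected" if clashes_existing and not self_rename else "allowed")
# id -> allowed   (ID -> id is only a case change of the same column)
# ID -> rejected  (name -> ID collides with the existing ID column)
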

snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1

@@ -14,7 +14,6 @@ from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils.pandas_udtf_utils import create_pandas_udtf
-from snowflake.snowpark_connect.utils.session import get_python_udxf_import_files
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
     process_udf_in_sproc,
@@ -28,6 +27,9 @@ from snowflake.snowpark_connect.utils.udtf_helper import (
     create_pandas_udtf_in_sproc,
     require_creating_udtf_in_sproc,
 )
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)


 def map_map_partitions(

snowflake/snowpark_connect/relation/map_sql.py +18 -191

@@ -14,7 +14,10 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 import sqlglot
 from google.protobuf.any_pb2 import Any
-from pyspark.errors.exceptions.base import
+from pyspark.errors.exceptions.base import (
+    AnalysisException,
+    UnsupportedOperationException,
+)
 from sqlglot.expressions import ColumnDef, DataType, FileFormatProperty, Identifier

 import snowflake.snowpark.functions as snowpark_fn
@@ -27,7 +30,6 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
 )
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
-from snowflake.snowpark.functions import when_matched, when_not_matched
 from snowflake.snowpark_connect.config import (
     auto_uppercase_non_column_identifiers,
     get_boolean_session_config_param,
@@ -56,16 +58,15 @@ from snowflake.snowpark_connect.utils.context import (
     get_session_id,
     push_evaluating_sql_scope,
     push_sql_scope,
-    set_plan_id_map,
     set_sql_args,
     set_sql_plan_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
+    telemetry,
 )

-from .. import column_name_handler
 from ..expression.map_sql_expression import (
     _window_specs,
     as_java_list,
@@ -163,6 +164,7 @@ def parse_pos_args(

 def execute_logical_plan(logical_plan) -> DataFrameContainer:
     proto = map_logical_plan_relation(logical_plan)
+    telemetry.report_parsed_sql_plan(proto)
     with push_evaluating_sql_scope():
         return map_relation(proto)

@@ -712,197 +714,22 @@ def map_sql_to_pandas_df(
                 f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
             ).collect()
         case "MergeIntoTable":
-
-
-
-                if action.condition().isDefined():
-                    (_, condition_typed_col,) = map_single_column_expression(
-                        map_logical_plan_expression(
-                            matched_action.condition().get()
-                        ),
-                        column_mapping,
-                        typer,
-                    )
-                    condition = condition_typed_col.col
-                return condition
-
-            def _get_assignments_from_action(
-                action,
-                column_mapping_source,
-                column_mapping_target,
-                typer_source,
-                typer_target,
-            ):
-                assignments = dict()
-                if (
-                    action.getClass().getSimpleName() == "InsertAction"
-                    or action.getClass().getSimpleName() == "UpdateAction"
-                ):
-                    incoming_assignments = as_java_list(action.assignments())
-                    for assignment in incoming_assignments:
-                        (key_name, _) = map_single_column_expression(
-                            map_logical_plan_expression(assignment.key()),
-                            column_mapping=column_mapping_target,
-                            typer=typer_source,
-                        )
-
-                        (_, val_typ_col) = map_single_column_expression(
-                            map_logical_plan_expression(assignment.value()),
-                            column_mapping=column_mapping_source,
-                            typer=typer_target,
-                        )
-
-                        assignments[key_name] = val_typ_col.col
-                elif (
-                    action.getClass().getSimpleName() == "InsertStarAction"
-                    or action.getClass().getSimpleName() == "UpdateStarAction"
-                ):
-                    if len(column_mapping_source.columns) != len(
-                        column_mapping_target.columns
-                    ):
-                        raise ValueError(
-                            "source and target must have the same number of columns for InsertStarAction or UpdateStarAction"
-                        )
-                    for i, col in enumerate(column_mapping_target.columns):
-                        if assignments.get(col.snowpark_name) is not None:
-                            raise SnowparkConnectNotImplementedError(
-                                "UpdateStarAction or InsertStarAction is not supported with duplicate columns."
-                            )
-                        assignments[col.snowpark_name] = snowpark_fn.col(
-                            column_mapping_source.columns[i].snowpark_name
-                        )
-                return assignments
-
-            source_df_container = map_relation(
-                map_logical_plan_relation(logical_plan.sourceTable())
-            )
-            source_df = source_df_container.dataframe
-            plan_id = gen_sql_plan_id()
-            target_df_container = map_relation(
-                map_logical_plan_relation(logical_plan.targetTable(), plan_id)
-            )
-            target_df = target_df_container.dataframe
-
-            for col in target_df_container.column_map.columns:
-                target_df = target_df.with_column_renamed(
-                    col.snowpark_name,
-                    spark_to_sf_single_id(col.spark_name, is_column=True),
-                )
-            target_df_container = DataFrameContainer.create_with_column_mapping(
-                dataframe=target_df,
-                spark_column_names=target_df.columns,
-                snowpark_column_names=target_df.columns,
-            )
-
-            set_plan_id_map(plan_id, target_df_container)
-
-            joined_df_before_condition: snowpark.DataFrame = source_df.join(
-                target_df
-            )
-
-            column_mapping_for_conditions = column_name_handler.JoinColumnNameMap(
-                source_df_container.column_map,
-                target_df_container.column_map,
-            )
-            typer_for_expressions = ExpressionTyper(joined_df_before_condition)
-
-            (_, merge_condition_typed_col,) = map_single_column_expression(
-                map_logical_plan_expression(logical_plan.mergeCondition()),
-                column_mapping=column_mapping_for_conditions,
-                typer=typer_for_expressions,
-            )
-
-            clauses = []
-
-            for matched_action in as_java_list(logical_plan.matchedActions()):
-                condition = _get_condition_from_action(
-                    matched_action,
-                    column_mapping_for_conditions,
-                    typer_for_expressions,
-                )
-                if matched_action.getClass().getSimpleName() == "DeleteAction":
-                    clauses.append(when_matched(condition).delete())
-                elif (
-                    matched_action.getClass().getSimpleName() == "UpdateAction"
-                    or matched_action.getClass().getSimpleName()
-                    == "UpdateStarAction"
-                ):
-                    assignments = _get_assignments_from_action(
-                        matched_action,
-                        source_df_container.column_map,
-                        target_df_container.column_map,
-                        ExpressionTyper(source_df),
-                        ExpressionTyper(target_df),
-                    )
-                    clauses.append(when_matched(condition).update(assignments))
-
-            for not_matched_action in as_java_list(
-                logical_plan.notMatchedActions()
-            ):
-                condition = _get_condition_from_action(
-                    not_matched_action,
-                    column_mapping_for_conditions,
-                    typer_for_expressions,
-                )
-                if (
-                    not_matched_action.getClass().getSimpleName() == "InsertAction"
-                    or not_matched_action.getClass().getSimpleName()
-                    == "InsertStarAction"
-                ):
-                    assignments = _get_assignments_from_action(
-                        not_matched_action,
-                        source_df_container.column_map,
-                        target_df_container.column_map,
-                        ExpressionTyper(source_df),
-                        ExpressionTyper(target_df),
-                    )
-                    clauses.append(when_not_matched(condition).insert(assignments))
-
-            if not as_java_list(logical_plan.notMatchedBySourceActions()).isEmpty():
-                raise SnowparkConnectNotImplementedError(
-                    "Snowflake does not support 'not matched by source' actions in MERGE statements."
-                )
-
-            if (
-                logical_plan.targetTable().getClass().getSimpleName()
-                == "UnresolvedRelation"
-            ):
-                target_table_name = _spark_to_snowflake(
-                    logical_plan.targetTable().multipartIdentifier()
-                )
-            else:
-                target_table_name = _spark_to_snowflake(
-                    logical_plan.targetTable().child().multipartIdentifier()
-                )
-            session.table(target_table_name).merge(
-                source_df, merge_condition_typed_col.col, clauses
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The MERGE INTO command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
         case "DeleteFromTable":
-
-
-
-            df = df_container.dataframe
-            for col in df_container.column_map.columns:
-                df = df.with_column_renamed(
-                    col.snowpark_name,
-                    spark_to_sf_single_id(col.spark_name, is_column=True),
-                )
-            df_container = column_name_handler.create_with_column_mapping(
-                dataframe=df,
-                spark_column_names=df.columns,
-                snowpark_column_names=df.columns,
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The DELETE FROM command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
-
-
-
-
-
-
-                map_logical_plan_expression(logical_plan.condition()),
-                df_container.column_map,
-                ExpressionTyper(df),
+        case "UpdateTable":
+            # Databricks/Delta-specific extension not supported by SAS.
+            # Provide an actionable, clear error.
+            raise UnsupportedOperationException(
+                "[UNSUPPORTED_SQL_EXTENSION] The UPDATE TABLE command failed.\n"
+                + "Reason: This command is a platform-specific SQL extension and is not part of the standard Apache Spark specification that this interface uses."
             )
-            session.table(name).delete(condition_typed_col.col)
         case "RenameColumn":
             table_name = get_relation_identifier_name(logical_plan.table(), True)
             column_obj = logical_plan.column()
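
Note on the change above: MERGE INTO and DELETE FROM are no longer translated into Snowpark DML, and UPDATE now gets the same explicit error; all three fail fast with UnsupportedOperationException and an [UNSUPPORTED_SQL_EXTENSION] message. A sketch of what client code would hit, assuming `spark` is an existing Spark Connect session backed by this server, the exception class is surfaced to the client unchanged, and the table names are hypothetical:

from pyspark.errors.exceptions.base import UnsupportedOperationException

try:
    spark.sql(
        "MERGE INTO target USING source ON target.id = source.id "
        "WHEN MATCHED THEN UPDATE SET *"
    )
except UnsupportedOperationException as exc:
    print(exc)  # [UNSUPPORTED_SQL_EXTENSION] The MERGE INTO command failed. ...
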

snowflake/snowpark_connect/relation/map_udtf.py +4 -4

@@ -31,10 +31,7 @@ from snowflake.snowpark_connect.type_mapping import (
     proto_to_snowpark_type,
 )
 from snowflake.snowpark_connect.utils.context import push_udtf_context
-from snowflake.snowpark_connect.utils.session import
-    get_or_create_snowpark_session,
-    get_python_udxf_import_files,
-)
+from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udtf_helper import (
     SnowparkUDTF,
     create_udtf_in_sproc,
@@ -42,6 +39,9 @@ from snowflake.snowpark_connect.utils.udtf_helper import (
     udtf_check,
 )
 from snowflake.snowpark_connect.utils.udtf_utils import create_udtf
+from snowflake.snowpark_connect.utils.udxf_import_utils import (
+    get_python_udxf_import_files,
+)


 def build_expected_types_from_parsed(

snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1

@@ -6,6 +6,7 @@ import copy
 import json
 import typing
 from contextlib import suppress
+from datetime import datetime

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto

@@ -21,6 +22,7 @@ from snowflake.snowpark.types import (
     StringType,
     StructField,
     StructType,
+    TimestampType,
 )
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.read.map_read import JsonReaderConfig
@@ -204,6 +206,8 @@ def merge_row_schema(
             next_level_content = row[col_name]
             if next_level_content is not None:
                 with suppress(json.JSONDecodeError):
+                    if isinstance(next_level_content, datetime):
+                        next_level_content = str(next_level_content)
                     next_level_content = json.loads(next_level_content)
                     if isinstance(next_level_content, dict):
                         sf.datatype = merge_json_schema(
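
Note on the datetime guard above: json.loads rejects a datetime object with a TypeError, which the surrounding suppress(json.JSONDecodeError) would not catch, so timestamp values are stringified before the parse attempt (and the following hunk maps TimestampType fields to StringType in the merged schema). A minimal sketch with a hypothetical row value:

import json
from contextlib import suppress
from datetime import datetime

next_level_content = datetime(2024, 1, 1, 12, 30)  # hypothetical value from a JSON row

with suppress(json.JSONDecodeError):
    if isinstance(next_level_content, datetime):
        # "2024-01-01 12:30:00" is not valid JSON, so loads() below raises
        # JSONDecodeError, which is suppressed instead of crashing inference.
        next_level_content = str(next_level_content)
    next_level_content = json.loads(next_level_content)

print(next_level_content)  # 2024-01-01 12:30:00 (kept as a plain string)
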
@@ -235,6 +239,9 @@ def merge_row_schema(
                             sf.datatype.element_type,
                             dropFieldIfAllNull,
                         )
+                    elif isinstance(sf.datatype, TimestampType):
+                        sf.datatype = StringType()
+                    columns_with_valid_contents.add(col_name)
             elif row[col_name] is not None:
                 columns_with_valid_contents.add(col_name)

@@ -265,7 +272,7 @@ def construct_dataframe_by_schema(
     rows: typing.Iterator[Row],
     session: snowpark.Session,
     snowpark_options: dict,
-    batch_size: int =
+    batch_size: int = 1000,
 ) -> snowpark.DataFrame:
     result = None

@@ -280,6 +287,8 @@ def construct_dataframe_by_schema(
                 session,
             )

+            current_data = []
+
     if len(current_data) > 0:
         result = union_data_into_df(
             result,
@@ -288,6 +297,8 @@ def construct_dataframe_by_schema(
             session,
         )

+        current_data = []
+
     if result is None:
         raise ValueError("Dataframe cannot be empty")
     return result
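
Note on the construct_dataframe_by_schema changes above: batch_size now defaults to 1000, and current_data is cleared after each flush into the running union. A minimal sketch of the batching pattern on plain lists; union_batches is a hypothetical stand-in for the Snowpark union logic:

def union_batches(rows, batch_size=1000):
    result, current_data = [], []
    for row in rows:
        current_data.append(row)
        if len(current_data) >= batch_size:
            result.extend(current_data)
            current_data = []  # clear the buffer after each flush
    if current_data:
        result.extend(current_data)
        current_data = []
    return result

assert union_batches(range(2500), batch_size=1000) == list(range(2500))
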