snowpark-connect 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/column_name_handler.py +42 -56
- snowflake/snowpark_connect/config.py +9 -0
- snowflake/snowpark_connect/expression/literal.py +12 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +6 -0
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +147 -63
- snowflake/snowpark_connect/expression/map_unresolved_function.py +31 -28
- snowflake/snowpark_connect/relation/map_aggregate.py +156 -255
- snowflake/snowpark_connect/relation/map_column_ops.py +14 -0
- snowflake/snowpark_connect/relation/map_join.py +364 -234
- snowflake/snowpark_connect/relation/map_sql.py +309 -150
- snowflake/snowpark_connect/relation/read/map_read.py +9 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
- snowflake/snowpark_connect/relation/read/map_read_json.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +41 -0
- snowflake/snowpark_connect/relation/utils.py +4 -2
- snowflake/snowpark_connect/relation/write/map_write.py +65 -17
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
- snowflake/snowpark_connect/utils/session.py +0 -4
- snowflake/snowpark_connect/utils/udf_helper.py +1 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +2 -2
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +35 -38
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +0 -16
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +0 -1281
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +0 -203
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +0 -202
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_exp_proto
 import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
 from snowflake import snowpark
+from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,
@@ -155,6 +156,48 @@ def _push_cte_scope():
         _cte_definitions.reset(def_token)


+def _process_cte_relations(cte_relations):
+    """
+    Process CTE relations and register them in the current CTE scope.
+
+    This function extracts CTE definitions from CTE relations,
+    maps them to protobuf representations, and stores them for later reference.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    for cte in as_java_list(cte_relations):
+        name = str(cte._1())
+        # Store the original CTE definition for re-evaluation
+        _cte_definitions.get()[name] = cte._2()
+        # Process CTE definition with a unique plan_id to ensure proper column naming
+        # Clear HAVING condition before processing each CTE to prevent leakage between CTEs
+        saved_having = _having_condition.get()
+        _having_condition.set(None)
+        try:
+            cte_plan_id = gen_sql_plan_id()
+            cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
+            _ctes.get()[name] = cte_proto
+        finally:
+            _having_condition.set(saved_having)
+
+
+@contextmanager
+def _with_cte_scope(cte_relations):
+    """
+    Context manager that creates a CTE scope and processes CTE relations.
+
+    This combines _push_cte_scope() and _process_cte_relations() to handle
+    the common pattern of processing CTEs within a new scope.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    with (_push_cte_scope()):
+        _process_cte_relations(cte_relations)
+        yield
+
+
 @contextmanager
 def _push_window_specs_scope():
     """
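
Aside: a minimal, self-contained sketch of the save/clear/restore pattern that _process_cte_relations applies to the HAVING state around each CTE. The ContextVar below is a stand-in for illustration, not the package's actual module-level state:

    from contextlib import contextmanager
    from contextvars import ContextVar

    # Hypothetical stand-in for the module-level HAVING state.
    _having_condition: ContextVar = ContextVar("having", default=None)

    @contextmanager
    def _isolated_having():
        # Save the current HAVING condition, clear it, and restore it afterwards,
        # so one CTE's HAVING clause cannot leak into the next CTE's translation.
        saved = _having_condition.get()
        _having_condition.set(None)
        try:
            yield
        finally:
            _having_condition.set(saved)

    if __name__ == "__main__":
        _having_condition.set("amount > 10")
        with _isolated_having():
            assert _having_condition.get() is None
        assert _having_condition.get() == "amount > 10"
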
@@ -261,6 +304,130 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
     )


+def _insert_into_table(logical_plan, session: Session) -> None:
+    df_container = execute_logical_plan(logical_plan.query())
+    df = df_container.dataframe
+    queries = df.queries["queries"]
+    if len(queries) != 1:
+        exception = SnowparkConnectNotImplementedError(
+            f"Unexpected number of queries: {len(queries)}"
+        )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    name = get_relation_identifier_name(logical_plan.table(), True)
+
+    user_columns = [
+        spark_to_sf_single_id(str(col), is_column=True)
+        for col in as_java_list(logical_plan.userSpecifiedCols())
+    ]
+    overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
+    cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+
+    # Extract partition spec if any
+    partition_spec = logical_plan.partitionSpec()
+    partition_map = as_java_map(partition_spec)
+
+    partition_columns = {}
+    for entry in partition_map.entrySet():
+        col_name = str(entry.getKey())
+        value_option = entry.getValue()
+        if value_option.isDefined():
+            partition_columns[col_name] = value_option.get()
+
+    # Add partition columns to the dataframe
+    if partition_columns:
+        """
+        Spark sends them in the partition spec and the values won't be present in the values array.
+        As snowflake does not support static partitions in INSERT INTO statements,
+        we need to add the partition columns to the dataframe as literal columns.
+
+        ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+        Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+        Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+        We need to add the partition columns to the dataframe as literal columns.
+
+        ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+            df = df.withColumn('hr', snowpark_fn.lit(10))
+
+        Then the final query will be:
+        INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+        """
+        for partition_col, partition_value in partition_columns.items():
+            df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
+
+    target_table = session.table(name)
+    target_schema = target_table.schema
+
+    expected_number_of_columns = (
+        len(user_columns) if user_columns else len(target_schema.fields)
+    )
+    if expected_number_of_columns != len(df.schema.fields):
+        reason = (
+            "too many data columns"
+            if len(df.schema.fields) > expected_number_of_columns
+            else "not enough data columns"
+        )
+        exception = AnalysisException(
+            f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+            f'Table columns: {", ".join(target_schema.names)}.\n'
+            f'Data columns: {", ".join(df.schema.names)}.'
+        )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+        raise exception
+
+    try:
+        # Modify df with type conversions and struct field name mapping
+        modified_columns = []
+        for source_field, target_field in zip(df.schema.fields, target_schema.fields):
+            col_name = source_field.name
+
+            # Handle different type conversions
+            if isinstance(
+                target_field.datatype, snowpark.types.DecimalType
+            ) and isinstance(
+                source_field.datatype,
+                (snowpark.types.FloatType, snowpark.types.DoubleType),
+            ):
+                # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
+                # Only apply this to floating-point source columns
+                modified_col = (
+                    snowpark_fn.when(
+                        snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
+                        snowpark_fn.lit(None),
+                    )
+                    .otherwise(snowpark_fn.col(col_name))
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            elif (
+                isinstance(target_field.datatype, snowpark.types.StructType)
+                and source_field.datatype != target_field.datatype
+            ):
+                # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                # This fixes INSERT INTO table with struct literals like (2, 3)
+                modified_col = (
+                    snowpark_fn.col(col_name)
+                    .cast(target_field.datatype, rename_fields=True)
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            else:
+                modified_columns.append(snowpark_fn.col(col_name))
+
+        df = df.select(modified_columns)
+    except Exception:
+        pass
+
+    queries = df.queries["queries"]
+    final_query = queries[0]
+    session.sql(
+        f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
+    ).collect()
+
+
 def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
     # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
     # if present, or to "spark.sql.caseSensitive".
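
Aside: a plain-Python sketch of the partition-spec rewrite described in the docstring above, reusing the docstring's own example values. No Snowflake session is needed; this only illustrates the shape of the data transformation:

    # Append static partition values to every row, mirroring what the diff does
    # with df.withColumn(col, snowpark_fn.lit(value)) for each partition column.
    rows = [("k1", 100), ("k2", 200), ("k3", 300)]
    partition_columns = {"ds": "2021-01-01", "hr": 10}

    rows_with_partitions = [row + tuple(partition_columns.values()) for row in rows]
    print(rows_with_partitions)
    # [('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)]
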
@@ -595,10 +762,30 @@ def map_sql_to_pandas_df(

             spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name

-
-
-
-
+            # extract ONLY top-level column definitions (not nested struct fields)
+            column_defs = []
+            schema_node = next(parsed_sql.find_all(sqlglot.exp.Schema), None)
+            if schema_node:
+                for expr in schema_node.expressions:
+                    if isinstance(expr, sqlglot.exp.ColumnDef):
+                        column_defs.append(expr)
+
+            num_columns = len(column_defs)
+            if num_columns > 0:
+                null_list_parts = []
+                for col_def in column_defs:
+                    col_name = spark_to_sf_single_id(col_def.name, is_column=True)
+                    col_type = col_def.kind
+                    if col_type:
+                        null_list_parts.append(
+                            f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}"
+                        )
+                    else:
+                        null_list_parts.append(f"NULL AS {col_name}")
+                null_list = ", ".join(null_list_parts)
+            else:
+                null_list = "*"
+
             empty_select = (
                 f" AS SELECT {null_list} WHERE 1 = 0"
                 if logical_plan.options().isEmpty()
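
Aside: a standalone sqlglot sketch of the NULL-projection logic added above, using an upper-casing stand-in for the package's spark_to_sf_single_id helper (the table and column names are made up):

    import sqlglot

    sql = "CREATE TABLE t (id INT, name STRING, tags ARRAY<STRING>)"
    parsed = sqlglot.parse_one(sql, read="spark")

    # Only top-level column definitions from the schema node, not nested struct fields.
    schema_node = next(parsed.find_all(sqlglot.exp.Schema), None)
    column_defs = [e for e in schema_node.expressions if isinstance(e, sqlglot.exp.ColumnDef)]

    null_list_parts = []
    for col_def in column_defs:
        col_name = col_def.name.upper()  # stand-in for spark_to_sf_single_id(...)
        col_type = col_def.kind
        if col_type:
            null_list_parts.append(f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}")
        else:
            null_list_parts.append(f"NULL AS {col_name}")

    print(", ".join(null_list_parts) or "*")
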
@@ -862,133 +1049,7 @@ def map_sql_to_pandas_df(
                 )
                 raise exception
         case "InsertIntoStatement":
-
-            df = df_container.dataframe
-            queries = df.queries["queries"]
-            if len(queries) != 1:
-                exception = SnowparkConnectNotImplementedError(
-                    f"Unexpected number of queries: {len(queries)}"
-                )
-                attach_custom_error_code(
-                    exception, ErrorCodes.UNSUPPORTED_OPERATION
-                )
-                raise exception
-
-            name = get_relation_identifier_name(logical_plan.table(), True)
-
-            user_columns = [
-                spark_to_sf_single_id(str(col), is_column=True)
-                for col in as_java_list(logical_plan.userSpecifiedCols())
-            ]
-            overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
-            cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
-
-            # Extract partition spec if any
-            partition_spec = logical_plan.partitionSpec()
-            partition_map = as_java_map(partition_spec)
-
-            partition_columns = {}
-            for entry in partition_map.entrySet():
-                col_name = str(entry.getKey())
-                value_option = entry.getValue()
-                if value_option.isDefined():
-                    partition_columns[col_name] = value_option.get()
-
-            # Add partition columns to the dataframe
-            if partition_columns:
-                """
-                Spark sends them in the partition spec and the values won't be present in the values array.
-                As snowflake does not support static partitions in INSERT INTO statements,
-                we need to add the partition columns to the dataframe as literal columns.
-
-                ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
-
-                Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
-                Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
-
-                We need to add the partition columns to the dataframe as literal columns.
-
-                ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
-                    df = df.withColumn('hr', snowpark_fn.lit(10))
-
-                Then the final query will be:
-                INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
-                """
-                for partition_col, partition_value in partition_columns.items():
-                    df = df.withColumn(
-                        partition_col, snowpark_fn.lit(partition_value)
-                    )
-
-            target_table = session.table(name)
-            target_schema = target_table.schema
-
-            expected_number_of_columns = (
-                len(user_columns) if user_columns else len(target_schema.fields)
-            )
-            if expected_number_of_columns != len(df.schema.fields):
-                reason = (
-                    "too many data columns"
-                    if len(df.schema.fields) > expected_number_of_columns
-                    else "not enough data columns"
-                )
-                exception = AnalysisException(
-                    f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
-                    f'Table columns: {", ".join(target_schema.names)}.\n'
-                    f'Data columns: {", ".join(df.schema.names)}.'
-                )
-                attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
-                raise exception
-
-            try:
-                # Modify df with type conversions and struct field name mapping
-                modified_columns = []
-                for source_field, target_field in zip(
-                    df.schema.fields, target_schema.fields
-                ):
-                    col_name = source_field.name
-
-                    # Handle different type conversions
-                    if isinstance(
-                        target_field.datatype, snowpark.types.DecimalType
-                    ) and isinstance(
-                        source_field.datatype,
-                        (snowpark.types.FloatType, snowpark.types.DoubleType),
-                    ):
-                        # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
-                        # Only apply this to floating-point source columns
-                        modified_col = (
-                            snowpark_fn.when(
-                                snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
-                                snowpark_fn.lit(None),
-                            )
-                            .otherwise(snowpark_fn.col(col_name))
-                            .alias(col_name)
-                        )
-                        modified_columns.append(modified_col)
-                    elif (
-                        isinstance(target_field.datatype, snowpark.types.StructType)
-                        and source_field.datatype != target_field.datatype
-                    ):
-                        # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
-                        # This fixes INSERT INTO table with struct literals like (2, 3)
-                        modified_col = (
-                            snowpark_fn.col(col_name)
-                            .cast(target_field.datatype, rename_fields=True)
-                            .alias(col_name)
-                        )
-                        modified_columns.append(modified_col)
-                    else:
-                        modified_columns.append(snowpark_fn.col(col_name))
-
-                df = df.select(modified_columns)
-            except Exception:
-                pass
-
-            queries = df.queries["queries"]
-            final_query = queries[0]
-            session.sql(
-                f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
-            ).collect()
+            _insert_into_table(logical_plan, session)
         case "MergeIntoTable":
             source_df_container = map_relation(
                 map_logical_plan_relation(logical_plan.sourceTable())
@@ -1445,6 +1506,16 @@ def map_sql_to_pandas_df(
                 raise exception

            return pandas.DataFrame({"": [""]}), ""
+        case "UnresolvedWith":
+            child = logical_plan.child()
+            child_class = str(child.getClass().getSimpleName())
+            match child_class:
+                case "InsertIntoStatement":
+                    with _with_cte_scope(logical_plan.cteRelations()):
+                        _insert_into_table(child, get_or_create_snowpark_session())
+                case _:
+                    execute_logical_plan(logical_plan)
+            return None, None
         case _:
             execute_logical_plan(logical_plan)
             return None, None
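
Aside: the kind of statement this new UnresolvedWith case routes is a WITH clause wrapping an INSERT, which the parser delivers as UnresolvedWith(child=InsertIntoStatement). The session and table names below are hypothetical:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # or a Spark Connect / Snowpark Connect session

    # "tgt" is assumed to already exist; the point is only the statement shape.
    spark.sql(
        "WITH src AS (SELECT 1 AS id, 'a' AS v) "
        "INSERT INTO tgt SELECT id, v FROM src"
    )
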
@@ -1923,13 +1994,104 @@ def map_logical_plan_relation(
                 )
             )
         case "Sort":
+            # Process the input first
+            input_proto = map_logical_plan_relation(rel.child())
+
+            # Check if child is a Project - if so, build an alias map for ORDER BY resolution
+            # This handles: SELECT o.date AS order_date ... ORDER BY o.date
+            child_class = str(rel.child().getClass().getSimpleName())
+            alias_map = {}
+
+            if child_class == "Project":
+                # Extract aliases from SELECT clause
+                for proj_expr in list(as_java_list(rel.child().projectList())):
+                    if str(proj_expr.getClass().getSimpleName()) == "Alias":
+                        alias_name = str(proj_expr.name())
+                        child_expr = proj_expr.child()
+
+                        # Store mapping from original expression to alias name
+                        # Use string representation for matching
+                        expr_str = str(child_expr)
+                        alias_map[expr_str] = alias_name
+
+                        # Also handle UnresolvedAttribute specifically to get the qualified name
+                        if (
+                            str(child_expr.getClass().getSimpleName())
+                            == "UnresolvedAttribute"
+                        ):
+                            # Get the qualified name like "o.date"
+                            name_parts = list(as_java_list(child_expr.nameParts()))
+                            qualified_name = ".".join(str(part) for part in name_parts)
+                            if qualified_name not in alias_map:
+                                alias_map[qualified_name] = alias_name
+
+            # Process ORDER BY expressions, substituting aliases where needed
+            order_list = []
+            for order_expr in as_java_list(rel.order()):
+                # Get the child expression from the SortOrder
+                child_expr = order_expr.child()
+                expr_class = str(child_expr.getClass().getSimpleName())
+
+                # Check if this expression matches any aliased expression
+                expr_str = str(child_expr)
+                substituted = False
+
+                if expr_str in alias_map:
+                    # Found a match - substitute with alias reference
+                    alias_name = alias_map[expr_str]
+                    # Create new UnresolvedAttribute for the alias
+                    UnresolvedAttribute = jpype.JClass(
+                        "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                    )
+                    new_attr = UnresolvedAttribute.quoted(alias_name)
+
+                    # Create new SortOrder with substituted expression
+                    SortOrder = jpype.JClass(
+                        "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                    )
+                    new_order = SortOrder(
+                        new_attr,
+                        order_expr.direction(),
+                        order_expr.nullOrdering(),
+                        order_expr.sameOrderExpressions(),
+                    )
+                    order_list.append(map_logical_plan_expression(new_order).sort_order)
+                    substituted = True
+                elif expr_class == "UnresolvedAttribute":
+                    # Try matching on qualified name
+                    name_parts = list(as_java_list(child_expr.nameParts()))
+                    qualified_name = ".".join(str(part) for part in name_parts)
+                    if qualified_name in alias_map:
+                        alias_name = alias_map[qualified_name]
+                        UnresolvedAttribute = jpype.JClass(
+                            "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                        )
+                        new_attr = UnresolvedAttribute.quoted(alias_name)

+                        SortOrder = jpype.JClass(
+                            "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                        )
+                        new_order = SortOrder(
+                            new_attr,
+                            order_expr.direction(),
+                            order_expr.nullOrdering(),
+                            order_expr.sameOrderExpressions(),
+                        )
+                        order_list.append(
+                            map_logical_plan_expression(new_order).sort_order
+                        )
+                        substituted = True
+
+                if not substituted:
+                    # No substitution needed - use original
+                    order_list.append(
+                        map_logical_plan_expression(order_expr).sort_order
+                    )
+
             proto = relation_proto.Relation(
                 sort=relation_proto.Sort(
-                    input=map_logical_plan_relation(rel.child()),
-                    order=[
-                        map_logical_plan_expression(e).sort_order
-                        for e in as_java_list(rel.order())
-                    ],
+                    input=input_proto,
+                    order=order_list,
                 )
             )
         case "SubqueryAlias":
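
Aside: a plain-Python sketch of the alias-substitution idea from the comments above (SELECT o.date AS order_date ... ORDER BY o.date), independent of the Catalyst classes involved:

    # Build an alias map from the SELECT list, then rewrite ORDER BY expressions
    # that match an aliased expression so they sort by the output alias instead.
    select_aliases = {"o.date": "order_date", "o.amount * o.qty": "total"}
    order_by = ["o.date", "o.id"]

    rewritten = [select_aliases.get(expr, expr) for expr in order_by]
    print(rewritten)  # ['order_date', 'o.id']
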
@@ -2116,10 +2278,16 @@ def map_logical_plan_relation(
            )

            # Re-evaluate the CTE definition with a fresh plan_id
-
-
-
-
+            # Clear HAVING condition to prevent leakage from outer CTEs
+            saved_having = _having_condition.get()
+            _having_condition.set(None)
+            try:
+                fresh_plan_id = gen_sql_plan_id()
+                fresh_cte_proto = map_logical_plan_relation(
+                    cte_definition, fresh_plan_id
+                )
+            finally:
+                _having_condition.set(saved_having)

            # Use SubqueryColumnAliases to ensure consistent column names across CTE references
            # This is crucial for CTEs that reference other CTEs
@@ -2274,16 +2442,7 @@ def map_logical_plan_relation(
                ),
            )
         case "UnresolvedWith":
-            with _push_cte_scope():
-                for cte in as_java_list(rel.cteRelations()):
-                    name = str(cte._1())
-                    # Store the original CTE definition for re-evaluation
-                    _cte_definitions.get()[name] = cte._2()
-                    # Process CTE definition with a unique plan_id to ensure proper column naming
-                    cte_plan_id = gen_sql_plan_id()
-                    cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
-                    _ctes.get()[name] = cte_proto
-
+            with _with_cte_scope(rel.cteRelations()):
                 proto = map_logical_plan_relation(rel.child())
         case "LateralJoin":
             left = map_logical_plan_relation(rel.left())
@@ -225,12 +225,20 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
     return None


+# TODO: [SNOW-2465948] Remove this once Snowpark fixes the issue with stage paths.
+class StagePathStr(str):
+    def partition(self, __sep):
+        if str(self)[0] == "'":
+            return str(self)[1:].partition(__sep)
+        return str(self).partition(__sep)
+
+
 def _quote_stage_path(stage_path: str) -> str:
     """
     Quote stage paths to escape any special characters.
     """
     if stage_path.startswith("@"):
-        return f"'{stage_path}'"
+        return StagePathStr(f"'{stage_path}'")
     return stage_path

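
Aside: a standalone sketch showing the effect of the StagePathStr override on a quoted stage path (illustrative only; the TODO above notes it should go away once Snowpark handles quoted stage paths):

    class StagePathStr(str):
        # When the path was wrapped in single quotes, drop the opening quote before
        # partitioning so callers that split on "/" see "@stage" and not "'@stage".
        def partition(self, sep):
            if str(self)[0] == "'":
                return str(self)[1:].partition(sep)
            return str(self).partition(sep)

    quoted = StagePathStr("'@my_stage/some dir/data.csv'")
    print(quoted.partition("/")[0])       # @my_stage
    print(str.partition(quoted, "/")[0])  # '@my_stage (plain str behavior)
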
@@ -6,6 +6,7 @@ import copy
 from typing import Any

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+from pyspark.errors.exceptions.base import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
@@ -20,6 +21,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
     get_non_metadata_fields,
 )
 from snowflake.snowpark_connect.relation.read.utils import (
+    apply_metadata_exclusion_pattern,
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
@@ -62,6 +64,8 @@ def map_read_csv(
     snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
         "INFER_SCHEMA", False
     )
+
+    apply_metadata_exclusion_pattern(snowpark_options)
     snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)

     raw_options = rel.read.data_source.options
@@ -157,6 +161,7 @@ def get_header_names(
     path: list[str],
     file_format_options: dict,
     snowpark_read_options: dict,
+    raw_options: dict,
 ) -> list[str]:
     no_header_file_format_options = copy.copy(file_format_options)
     no_header_file_format_options["PARSE_HEADER"] = False
@@ -168,7 +173,19 @@ def get_header_names(
     no_header_snowpark_read_options.pop("INFER_SCHEMA", None)

     header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
-
+    collected_data = header_df.collect()
+
+    if len(collected_data) == 0:
+        error_msg = f"Path does not exist or contains no data: {path}"
+        user_pattern = raw_options.get("pathGlobFilter", None)
+        if user_pattern:
+            error_msg += f" (with pathGlobFilter: {user_pattern})"
+
+        exception = AnalysisException(error_msg)
+        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+        raise exception
+
+    header_data = collected_data[0]
     return [
         f'"{header_data[i]}"'
         for i in range(len(header_df.schema.fields))
@@ -207,7 +224,7 @@ def read_data(
         return df

     headers = get_header_names(
-        session, path, file_format_options, snowpark_read_options
+        session, path, file_format_options, snowpark_read_options, raw_options
     )

     df_schema_fields = non_metadata_fields
@@ -35,6 +35,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
     add_filename_metadata_to_reader,
 )
 from snowflake.snowpark_connect.relation.read.utils import (
+    apply_metadata_exclusion_pattern,
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
@@ -80,6 +81,8 @@ def map_read_json(
     dropFieldIfAllNull = snowpark_options.pop("dropfieldifallnull", False)
     batch_size = snowpark_options.pop("batchsize", 1000)

+    apply_metadata_exclusion_pattern(snowpark_options)
+
     reader = add_filename_metadata_to_reader(
         session.read.options(snowpark_options), raw_options
     )
|
@@ -29,6 +29,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
|
|
|
29
29
|
)
|
|
30
30
|
from snowflake.snowpark_connect.relation.read.reader_config import ReaderWriterConfig
|
|
31
31
|
from snowflake.snowpark_connect.relation.read.utils import (
|
|
32
|
+
apply_metadata_exclusion_pattern,
|
|
32
33
|
rename_columns_as_snowflake_standard,
|
|
33
34
|
)
|
|
34
35
|
from snowflake.snowpark_connect.utils.telemetry import (
|
|
@@ -57,6 +58,8 @@ def map_read_parquet(
     assert schema is None, "Read PARQUET does not support user schema"
     assert len(paths) > 0, "Read PARQUET expects at least one path"

+    apply_metadata_exclusion_pattern(snowpark_options)
+
     reader = add_filename_metadata_to_reader(
         session.read.options(snowpark_options), raw_options
     )
@@ -26,6 +26,10 @@ def get_file_paths_from_stage(
 ) -> typing.List[str]:
     files_paths = []
     for listed_path_row in session.sql(f"LIST {path}").collect():
+        # Skip _SUCCESS marker files
+        if listed_path_row[0].endswith("_SUCCESS"):
+            continue
+
         listed_path = listed_path_row[0].split("/")
         if listed_path_row[0].startswith("s3://") or listed_path_row[0].startswith(
             "s3a://"
@@ -126,6 +126,7 @@ CSV_READ_SUPPORTED_OPTIONS = lowercase_set(
         "compression",
         # "escapeQuotes",
         # "quoteAll",
+        "rowsToInferSchema",  # Snowflake specific option, number of rows to infer schema
     }
 )

@@ -201,6 +202,15 @@ def csv_convert_to_snowpark_args(snowpark_config: dict[str, Any]) -> dict[str, A
     if snowpark_config["escape"] and snowpark_config["escape"] == "\\":
         snowpark_config["escape"] = "\\\\"

+    # Snowflake specific option, number of rows to infer schema for CSV files
+    if "rowstoinferschema" in snowpark_config:
+        rows_to_infer_schema = snowpark_config["rowstoinferschema"]
+        del snowpark_config["rowstoinferschema"]
+        snowpark_config["INFER_SCHEMA_OPTIONS"] = {
+            "MAX_RECORDS_PER_FILE": int(rows_to_infer_schema),
+            "USE_RELAXED_TYPES": True,
+        }
+
     # Rename the keys to match the Snowpark configuration.
     for spark_arg, snowpark_arg in renamed_args.items():
         if spark_arg not in snowpark_config:
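
Aside: a minimal sketch of the rowsToInferSchema translation added above, assuming the option key has already been lower-cased by the reader config layer:

    def translate_rows_to_infer_schema(snowpark_config: dict) -> dict:
        # Convert the Spark-side "rowsToInferSchema" reader option (lower-cased by the
        # config layer) into Snowflake's INFER_SCHEMA_OPTIONS structure.
        if "rowstoinferschema" in snowpark_config:
            rows = snowpark_config.pop("rowstoinferschema")
            snowpark_config["INFER_SCHEMA_OPTIONS"] = {
                "MAX_RECORDS_PER_FILE": int(rows),
                "USE_RELAXED_TYPES": True,
            }
        return snowpark_config

    print(translate_rows_to_infer_schema({"rowstoinferschema": "100"}))
    # {'INFER_SCHEMA_OPTIONS': {'MAX_RECORDS_PER_FILE': 100, 'USE_RELAXED_TYPES': True}}

On the Spark side, this corresponds to passing the option on a CSV read, e.g. .option("rowsToInferSchema", 100), since the option is listed in CSV_READ_SUPPORTED_OPTIONS above.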