snowpark-connect 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of snowpark-connect might be problematic.
Files changed (39)
  1. snowflake/snowpark_connect/column_name_handler.py +42 -56
  2. snowflake/snowpark_connect/config.py +9 -0
  3. snowflake/snowpark_connect/expression/literal.py +12 -12
  4. snowflake/snowpark_connect/expression/map_sql_expression.py +6 -0
  5. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +147 -63
  6. snowflake/snowpark_connect/expression/map_unresolved_function.py +31 -28
  7. snowflake/snowpark_connect/relation/map_aggregate.py +156 -255
  8. snowflake/snowpark_connect/relation/map_column_ops.py +14 -0
  9. snowflake/snowpark_connect/relation/map_join.py +364 -234
  10. snowflake/snowpark_connect/relation/map_sql.py +309 -150
  11. snowflake/snowpark_connect/relation/read/map_read.py +9 -1
  12. snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
  13. snowflake/snowpark_connect/relation/read/map_read_json.py +3 -0
  14. snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
  15. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
  16. snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
  17. snowflake/snowpark_connect/relation/read/utils.py +41 -0
  18. snowflake/snowpark_connect/relation/utils.py +4 -2
  19. snowflake/snowpark_connect/relation/write/map_write.py +65 -17
  20. snowflake/snowpark_connect/utils/context.py +0 -14
  21. snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
  22. snowflake/snowpark_connect/utils/session.py +0 -4
  23. snowflake/snowpark_connect/utils/udf_helper.py +1 -0
  24. snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
  25. snowflake/snowpark_connect/version.py +1 -1
  26. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +2 -2
  27. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +35 -38
  28. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +0 -16
  29. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +0 -1281
  30. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +0 -203
  31. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +0 -202
  32. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
  33. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
  34. {snowpark_connect-0.33.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
  35. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
  36. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
  37. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
  38. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
  39. {snowpark_connect-0.33.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ import snowflake.snowpark.functions as snowpark_fn
  import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_exp_proto
  import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
  from snowflake import snowpark
+ from snowflake.snowpark import Session
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
      quote_name_without_upper_casing,
      unquote_if_quoted,
@@ -155,6 +156,48 @@ def _push_cte_scope():
      _cte_definitions.reset(def_token)


+ def _process_cte_relations(cte_relations):
+     """
+     Process CTE relations and register them in the current CTE scope.
+
+     This function extracts CTE definitions from CTE relations,
+     maps them to protobuf representations, and stores them for later reference.
+
+     Args:
+         cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+     """
+     for cte in as_java_list(cte_relations):
+         name = str(cte._1())
+         # Store the original CTE definition for re-evaluation
+         _cte_definitions.get()[name] = cte._2()
+         # Process CTE definition with a unique plan_id to ensure proper column naming
+         # Clear HAVING condition before processing each CTE to prevent leakage between CTEs
+         saved_having = _having_condition.get()
+         _having_condition.set(None)
+         try:
+             cte_plan_id = gen_sql_plan_id()
+             cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
+             _ctes.get()[name] = cte_proto
+         finally:
+             _having_condition.set(saved_having)
+
+
+ @contextmanager
+ def _with_cte_scope(cte_relations):
+     """
+     Context manager that creates a CTE scope and processes CTE relations.
+
+     This combines _push_cte_scope() and _process_cte_relations() to handle
+     the common pattern of processing CTEs within a new scope.
+
+     Args:
+         cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+     """
+     with (_push_cte_scope()):
+         _process_cte_relations(cte_relations)
+         yield
+
+
  @contextmanager
  def _push_window_specs_scope():
      """
@@ -261,6 +304,130 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
      )


+ def _insert_into_table(logical_plan, session: Session) -> None:
+     df_container = execute_logical_plan(logical_plan.query())
+     df = df_container.dataframe
+     queries = df.queries["queries"]
+     if len(queries) != 1:
+         exception = SnowparkConnectNotImplementedError(
+             f"Unexpected number of queries: {len(queries)}"
+         )
+         attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+         raise exception
+
+     name = get_relation_identifier_name(logical_plan.table(), True)
+
+     user_columns = [
+         spark_to_sf_single_id(str(col), is_column=True)
+         for col in as_java_list(logical_plan.userSpecifiedCols())
+     ]
+     overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
+     cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+
+     # Extract partition spec if any
+     partition_spec = logical_plan.partitionSpec()
+     partition_map = as_java_map(partition_spec)
+
+     partition_columns = {}
+     for entry in partition_map.entrySet():
+         col_name = str(entry.getKey())
+         value_option = entry.getValue()
+         if value_option.isDefined():
+             partition_columns[col_name] = value_option.get()
+
+     # Add partition columns to the dataframe
+     if partition_columns:
+         """
+         Spark sends them in the partition spec and the values won't be present in the values array.
+         As snowflake does not support static partitions in INSERT INTO statements,
+         we need to add the partition columns to the dataframe as literal columns.
+
+         ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+         Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+         Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+         We need to add the partition columns to the dataframe as literal columns.
+
+         ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+             df = df.withColumn('hr', snowpark_fn.lit(10))
+
+         Then the final query will be:
+         INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+         """
+         for partition_col, partition_value in partition_columns.items():
+             df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
+
+     target_table = session.table(name)
+     target_schema = target_table.schema
+
+     expected_number_of_columns = (
+         len(user_columns) if user_columns else len(target_schema.fields)
+     )
+     if expected_number_of_columns != len(df.schema.fields):
+         reason = (
+             "too many data columns"
+             if len(df.schema.fields) > expected_number_of_columns
+             else "not enough data columns"
+         )
+         exception = AnalysisException(
+             f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+             f'Table columns: {", ".join(target_schema.names)}.\n'
+             f'Data columns: {", ".join(df.schema.names)}.'
+         )
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         raise exception
+
+     try:
+         # Modify df with type conversions and struct field name mapping
+         modified_columns = []
+         for source_field, target_field in zip(df.schema.fields, target_schema.fields):
+             col_name = source_field.name
+
+             # Handle different type conversions
+             if isinstance(
+                 target_field.datatype, snowpark.types.DecimalType
+             ) and isinstance(
+                 source_field.datatype,
+                 (snowpark.types.FloatType, snowpark.types.DoubleType),
+             ):
+                 # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
+                 # Only apply this to floating-point source columns
+                 modified_col = (
+                     snowpark_fn.when(
+                         snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
+                         snowpark_fn.lit(None),
+                     )
+                     .otherwise(snowpark_fn.col(col_name))
+                     .alias(col_name)
+                 )
+                 modified_columns.append(modified_col)
+             elif (
+                 isinstance(target_field.datatype, snowpark.types.StructType)
+                 and source_field.datatype != target_field.datatype
+             ):
+                 # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                 # This fixes INSERT INTO table with struct literals like (2, 3)
+                 modified_col = (
+                     snowpark_fn.col(col_name)
+                     .cast(target_field.datatype, rename_fields=True)
+                     .alias(col_name)
+                 )
+                 modified_columns.append(modified_col)
+             else:
+                 modified_columns.append(snowpark_fn.col(col_name))
+
+         df = df.select(modified_columns)
+     except Exception:
+         pass
+
+     queries = df.queries["queries"]
+     final_query = queries[0]
+     session.sql(
+         f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
+     ).collect()
+
+
  def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
      # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
      # if present, or to "spark.sql.caseSensitive".
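
Note, as an illustrative sketch only (the helper name and column are assumptions, not code from the package): the NaN handling that _insert_into_table applies before writing FLOAT/DOUBLE data into DECIMAL columns can be reproduced on a plain Snowpark DataFrame with the same WHEN/OTHERWISE expression used above.

    import snowflake.snowpark.functions as snowpark_fn

    def nan_to_null(df, col_name):
        # NaN cannot be stored in a DECIMAL column, so rewrite it to NULL,
        # mirroring the expression built in _insert_into_table above.
        return df.with_column(
            col_name,
            snowpark_fn.when(
                snowpark_fn.equal_nan(snowpark_fn.col(col_name)), snowpark_fn.lit(None)
            ).otherwise(snowpark_fn.col(col_name)),
        )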
@@ -595,10 +762,30 @@ def map_sql_to_pandas_df(

              spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name

-             num_columns = len(list(parsed_sql.find_all(sqlglot.exp.ColumnDef)))
-             null_list = (
-                 ", ".join(["NULL"] * num_columns) if num_columns > 0 else "*"
-             )
+             # extract ONLY top-level column definitions (not nested struct fields)
+             column_defs = []
+             schema_node = next(parsed_sql.find_all(sqlglot.exp.Schema), None)
+             if schema_node:
+                 for expr in schema_node.expressions:
+                     if isinstance(expr, sqlglot.exp.ColumnDef):
+                         column_defs.append(expr)
+
+             num_columns = len(column_defs)
+             if num_columns > 0:
+                 null_list_parts = []
+                 for col_def in column_defs:
+                     col_name = spark_to_sf_single_id(col_def.name, is_column=True)
+                     col_type = col_def.kind
+                     if col_type:
+                         null_list_parts.append(
+                             f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}"
+                         )
+                     else:
+                         null_list_parts.append(f"NULL AS {col_name}")
+                 null_list = ", ".join(null_list_parts)
+             else:
+                 null_list = "*"
+
              empty_select = (
                  f" AS SELECT {null_list} WHERE 1 = 0"
                  if logical_plan.options().isEmpty()
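
For orientation, with a view defined over a schema such as (id INT, name STRING), the placeholder projection is now built from typed NULLs instead of bare NULLs. The exact identifier quoting and type names below are assumptions (they depend on spark_to_sf_single_id and sqlglot's snowflake dialect); only the shape is meant to be accurate.

    # Before: null_list == "NULL, NULL"
    # After (approximate rendering for the assumed schema):
    null_list = 'CAST(NULL AS INT) AS "ID", CAST(NULL AS VARCHAR) AS "NAME"'
    empty_select = f" AS SELECT {null_list} WHERE 1 = 0"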
@@ -862,133 +1049,7 @@ def map_sql_to_pandas_df(
              )
              raise exception
          case "InsertIntoStatement":
-             df_container = execute_logical_plan(logical_plan.query())
-             df = df_container.dataframe
-             queries = df.queries["queries"]
-             if len(queries) != 1:
-                 exception = SnowparkConnectNotImplementedError(
-                     f"Unexpected number of queries: {len(queries)}"
-                 )
-                 attach_custom_error_code(
-                     exception, ErrorCodes.UNSUPPORTED_OPERATION
-                 )
-                 raise exception
-
-             name = get_relation_identifier_name(logical_plan.table(), True)
-
-             user_columns = [
-                 spark_to_sf_single_id(str(col), is_column=True)
-                 for col in as_java_list(logical_plan.userSpecifiedCols())
-             ]
-             overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
-             cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
-
-             # Extract partition spec if any
-             partition_spec = logical_plan.partitionSpec()
-             partition_map = as_java_map(partition_spec)
-
-             partition_columns = {}
-             for entry in partition_map.entrySet():
-                 col_name = str(entry.getKey())
-                 value_option = entry.getValue()
-                 if value_option.isDefined():
-                     partition_columns[col_name] = value_option.get()
-
-             # Add partition columns to the dataframe
-             if partition_columns:
-                 """
-                 Spark sends them in the partition spec and the values won't be present in the values array.
-                 As snowflake does not support static partitions in INSERT INTO statements,
-                 we need to add the partition columns to the dataframe as literal columns.
-
-                 ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
-
-                 Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
-                 Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
-
-                 We need to add the partition columns to the dataframe as literal columns.
-
-                 ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
-                     df = df.withColumn('hr', snowpark_fn.lit(10))
-
-                 Then the final query will be:
-                 INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
-                 """
-                 for partition_col, partition_value in partition_columns.items():
-                     df = df.withColumn(
-                         partition_col, snowpark_fn.lit(partition_value)
-                     )
-
-             target_table = session.table(name)
-             target_schema = target_table.schema
-
-             expected_number_of_columns = (
-                 len(user_columns) if user_columns else len(target_schema.fields)
-             )
-             if expected_number_of_columns != len(df.schema.fields):
-                 reason = (
-                     "too many data columns"
-                     if len(df.schema.fields) > expected_number_of_columns
-                     else "not enough data columns"
-                 )
-                 exception = AnalysisException(
-                     f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
-                     f'Table columns: {", ".join(target_schema.names)}.\n'
-                     f'Data columns: {", ".join(df.schema.names)}.'
-                 )
-                 attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
-                 raise exception
-
-             try:
-                 # Modify df with type conversions and struct field name mapping
-                 modified_columns = []
-                 for source_field, target_field in zip(
-                     df.schema.fields, target_schema.fields
-                 ):
-                     col_name = source_field.name
-
-                     # Handle different type conversions
-                     if isinstance(
-                         target_field.datatype, snowpark.types.DecimalType
-                     ) and isinstance(
-                         source_field.datatype,
-                         (snowpark.types.FloatType, snowpark.types.DoubleType),
-                     ):
-                         # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
-                         # Only apply this to floating-point source columns
-                         modified_col = (
-                             snowpark_fn.when(
-                                 snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
-                                 snowpark_fn.lit(None),
-                             )
-                             .otherwise(snowpark_fn.col(col_name))
-                             .alias(col_name)
-                         )
-                         modified_columns.append(modified_col)
-                     elif (
-                         isinstance(target_field.datatype, snowpark.types.StructType)
-                         and source_field.datatype != target_field.datatype
-                     ):
-                         # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
-                         # This fixes INSERT INTO table with struct literals like (2, 3)
-                         modified_col = (
-                             snowpark_fn.col(col_name)
-                             .cast(target_field.datatype, rename_fields=True)
-                             .alias(col_name)
-                         )
-                         modified_columns.append(modified_col)
-                     else:
-                         modified_columns.append(snowpark_fn.col(col_name))
-
-                 df = df.select(modified_columns)
-             except Exception:
-                 pass
-
-             queries = df.queries["queries"]
-             final_query = queries[0]
-             session.sql(
-                 f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
-             ).collect()
+             _insert_into_table(logical_plan, session)
          case "MergeIntoTable":
              source_df_container = map_relation(
                  map_logical_plan_relation(logical_plan.sourceTable())
@@ -1445,6 +1506,16 @@ def map_sql_to_pandas_df(
              raise exception

              return pandas.DataFrame({"": [""]}), ""
+         case "UnresolvedWith":
+             child = logical_plan.child()
+             child_class = str(child.getClass().getSimpleName())
+             match child_class:
+                 case "InsertIntoStatement":
+                     with _with_cte_scope(logical_plan.cteRelations()):
+                         _insert_into_table(child, get_or_create_snowpark_session())
+                 case _:
+                     execute_logical_plan(logical_plan)
+                     return None, None
          case _:
              execute_logical_plan(logical_plan)
              return None, None
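
The new UnresolvedWith branch covers statements where a CTE wraps an INSERT, roughly of the following shape (hypothetical table names; spark stands for a Spark Connect session):

    # A CTE (UnresolvedWith) whose child is an InsertIntoStatement.
    spark.sql(
        """
        WITH recent AS (SELECT * FROM src WHERE ds = '2021-01-01')
        INSERT INTO tgt SELECT * FROM recent
        """
    )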
@@ -1923,13 +1994,104 @@ def map_logical_plan_relation(
              )
          )
      case "Sort":
+         # Process the input first
+         input_proto = map_logical_plan_relation(rel.child())
+
+         # Check if child is a Project - if so, build an alias map for ORDER BY resolution
+         # This handles: SELECT o.date AS order_date ... ORDER BY o.date
+         child_class = str(rel.child().getClass().getSimpleName())
+         alias_map = {}
+
+         if child_class == "Project":
+             # Extract aliases from SELECT clause
+             for proj_expr in list(as_java_list(rel.child().projectList())):
+                 if str(proj_expr.getClass().getSimpleName()) == "Alias":
+                     alias_name = str(proj_expr.name())
+                     child_expr = proj_expr.child()
+
+                     # Store mapping from original expression to alias name
+                     # Use string representation for matching
+                     expr_str = str(child_expr)
+                     alias_map[expr_str] = alias_name
+
+                     # Also handle UnresolvedAttribute specifically to get the qualified name
+                     if (
+                         str(child_expr.getClass().getSimpleName())
+                         == "UnresolvedAttribute"
+                     ):
+                         # Get the qualified name like "o.date"
+                         name_parts = list(as_java_list(child_expr.nameParts()))
+                         qualified_name = ".".join(str(part) for part in name_parts)
+                         if qualified_name not in alias_map:
+                             alias_map[qualified_name] = alias_name
+
+         # Process ORDER BY expressions, substituting aliases where needed
+         order_list = []
+         for order_expr in as_java_list(rel.order()):
+             # Get the child expression from the SortOrder
+             child_expr = order_expr.child()
+             expr_class = str(child_expr.getClass().getSimpleName())
+
+             # Check if this expression matches any aliased expression
+             expr_str = str(child_expr)
+             substituted = False
+
+             if expr_str in alias_map:
+                 # Found a match - substitute with alias reference
+                 alias_name = alias_map[expr_str]
+                 # Create new UnresolvedAttribute for the alias
+                 UnresolvedAttribute = jpype.JClass(
+                     "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                 )
+                 new_attr = UnresolvedAttribute.quoted(alias_name)
+
+                 # Create new SortOrder with substituted expression
+                 SortOrder = jpype.JClass(
+                     "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                 )
+                 new_order = SortOrder(
+                     new_attr,
+                     order_expr.direction(),
+                     order_expr.nullOrdering(),
+                     order_expr.sameOrderExpressions(),
+                 )
+                 order_list.append(map_logical_plan_expression(new_order).sort_order)
+                 substituted = True
+             elif expr_class == "UnresolvedAttribute":
+                 # Try matching on qualified name
+                 name_parts = list(as_java_list(child_expr.nameParts()))
+                 qualified_name = ".".join(str(part) for part in name_parts)
+                 if qualified_name in alias_map:
+                     alias_name = alias_map[qualified_name]
+                     UnresolvedAttribute = jpype.JClass(
+                         "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                     )
+                     new_attr = UnresolvedAttribute.quoted(alias_name)
+
+                     SortOrder = jpype.JClass(
+                         "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                     )
+                     new_order = SortOrder(
+                         new_attr,
+                         order_expr.direction(),
+                         order_expr.nullOrdering(),
+                         order_expr.sameOrderExpressions(),
+                     )
+                     order_list.append(
+                         map_logical_plan_expression(new_order).sort_order
+                     )
+                     substituted = True
+
+             if not substituted:
+                 # No substitution needed - use original
+                 order_list.append(
+                     map_logical_plan_expression(order_expr).sort_order
+                 )
+
          proto = relation_proto.Relation(
              sort=relation_proto.Sort(
-                 input=map_logical_plan_relation(rel.child()),
-                 order=[
-                     map_logical_plan_expression(e).sort_order
-                     for e in as_java_list(rel.order())
-                 ],
+                 input=input_proto,
+                 order=order_list,
              )
          )
      case "SubqueryAlias":
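
The alias map built above targets queries whose ORDER BY still refers to the pre-alias column while the SELECT list renames it, as in this hypothetical query (table and column names assumed, mirroring the comment in the diff):

    spark.sql(
        """
        SELECT o.date AS order_date, o.total
        FROM orders o
        ORDER BY o.date
        """
    )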
@@ -2116,10 +2278,16 @@ def map_logical_plan_relation(
              )

              # Re-evaluate the CTE definition with a fresh plan_id
-             fresh_plan_id = gen_sql_plan_id()
-             fresh_cte_proto = map_logical_plan_relation(
-                 cte_definition, fresh_plan_id
-             )
+             # Clear HAVING condition to prevent leakage from outer CTEs
+             saved_having = _having_condition.get()
+             _having_condition.set(None)
+             try:
+                 fresh_plan_id = gen_sql_plan_id()
+                 fresh_cte_proto = map_logical_plan_relation(
+                     cte_definition, fresh_plan_id
+                 )
+             finally:
+                 _having_condition.set(saved_having)

              # Use SubqueryColumnAliases to ensure consistent column names across CTE references
              # This is crucial for CTEs that reference other CTEs
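
The same save/clear/restore dance around _having_condition now appears at both new call sites; a generic sketch of the pattern (assuming the holder exposes get()/set(), as _having_condition appears to) looks like this:

    from contextlib import contextmanager

    @contextmanager
    def _cleared(holder):
        # Temporarily clear a context-scoped value so nested mapping cannot see it,
        # then restore the previous value even if the mapping raises.
        saved = holder.get()
        holder.set(None)
        try:
            yield
        finally:
            holder.set(saved)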
@@ -2274,16 +2442,7 @@ def map_logical_plan_relation(
                  ),
              )
          case "UnresolvedWith":
-             with _push_cte_scope():
-                 for cte in as_java_list(rel.cteRelations()):
-                     name = str(cte._1())
-                     # Store the original CTE definition for re-evaluation
-                     _cte_definitions.get()[name] = cte._2()
-                     # Process CTE definition with a unique plan_id to ensure proper column naming
-                     cte_plan_id = gen_sql_plan_id()
-                     cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
-                     _ctes.get()[name] = cte_proto
-
+             with _with_cte_scope(rel.cteRelations()):
                  proto = map_logical_plan_relation(rel.child())
          case "LateralJoin":
              left = map_logical_plan_relation(rel.left())
@@ -225,12 +225,20 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
      return None


+ # TODO: [SNOW-2465948] Remove this once Snowpark fixes the issue with stage paths.
+ class StagePathStr(str):
+     def partition(self, __sep):
+         if str(self)[0] == "'":
+             return str(self)[1:].partition(__sep)
+         return str(self).partition(__sep)
+
+
  def _quote_stage_path(stage_path: str) -> str:
      """
      Quote stage paths to escape any special characters.
      """
      if stage_path.startswith("@"):
-         return f"'{stage_path}'"
+         return StagePathStr(f"'{stage_path}'")
      return stage_path


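The StagePathStr workaround only changes how str.partition treats the leading quote added by _quote_stage_path; with an assumed stage path it behaves as follows:

    quoted = StagePathStr("'@my_stage/data/file.csv'")
    # The opening quote is skipped, so the split happens on the real stage path:
    quoted.partition("/")  # -> ("@my_stage", "/", "data/file.csv'")
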
@@ -6,6 +6,7 @@ import copy
  from typing import Any

  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+ from pyspark.errors.exceptions.base import AnalysisException

  import snowflake.snowpark.functions as snowpark_fn
  from snowflake import snowpark
@@ -20,6 +21,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
      get_non_metadata_fields,
  )
  from snowflake.snowpark_connect.relation.read.utils import (
+     apply_metadata_exclusion_pattern,
      get_spark_column_names_from_snowpark_columns,
      rename_columns_as_snowflake_standard,
  )
@@ -62,6 +64,8 @@ def map_read_csv(
      snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
          "INFER_SCHEMA", False
      )
+
+     apply_metadata_exclusion_pattern(snowpark_options)
      snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)

      raw_options = rel.read.data_source.options
@@ -157,6 +161,7 @@ def get_header_names(
      path: list[str],
      file_format_options: dict,
      snowpark_read_options: dict,
+     raw_options: dict,
  ) -> list[str]:
      no_header_file_format_options = copy.copy(file_format_options)
      no_header_file_format_options["PARSE_HEADER"] = False
@@ -168,7 +173,19 @@ def get_header_names(
      no_header_snowpark_read_options.pop("INFER_SCHEMA", None)

      header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
-     header_data = header_df.collect()[0]
+     collected_data = header_df.collect()
+
+     if len(collected_data) == 0:
+         error_msg = f"Path does not exist or contains no data: {path}"
+         user_pattern = raw_options.get("pathGlobFilter", None)
+         if user_pattern:
+             error_msg += f" (with pathGlobFilter: {user_pattern})"
+
+         exception = AnalysisException(error_msg)
+         attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+         raise exception
+
+     header_data = collected_data[0]
      return [
          f'"{header_data[i]}"'
          for i in range(len(header_df.schema.fields))
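
With this change, a CSV read whose path (or pathGlobFilter) matches no files fails with an AnalysisException instead of an IndexError from collect()[0]. A hypothetical reader call that would now produce the clearer error (stage path and session name assumed):

    spark.read.option("header", True).option("pathGlobFilter", "*.csv").csv(
        "@my_stage/empty_dir/"
    )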
@@ -207,7 +224,7 @@ def read_data(
          return df

      headers = get_header_names(
-         session, path, file_format_options, snowpark_read_options
+         session, path, file_format_options, snowpark_read_options, raw_options
      )

      df_schema_fields = non_metadata_fields
@@ -35,6 +35,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
      add_filename_metadata_to_reader,
  )
  from snowflake.snowpark_connect.relation.read.utils import (
+     apply_metadata_exclusion_pattern,
      get_spark_column_names_from_snowpark_columns,
      rename_columns_as_snowflake_standard,
  )
@@ -80,6 +81,8 @@ def map_read_json(
      dropFieldIfAllNull = snowpark_options.pop("dropfieldifallnull", False)
      batch_size = snowpark_options.pop("batchsize", 1000)

+     apply_metadata_exclusion_pattern(snowpark_options)
+
      reader = add_filename_metadata_to_reader(
          session.read.options(snowpark_options), raw_options
      )
@@ -29,6 +29,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
  )
  from snowflake.snowpark_connect.relation.read.reader_config import ReaderWriterConfig
  from snowflake.snowpark_connect.relation.read.utils import (
+     apply_metadata_exclusion_pattern,
      rename_columns_as_snowflake_standard,
  )
  from snowflake.snowpark_connect.utils.telemetry import (
@@ -57,6 +58,8 @@ def map_read_parquet(
      assert schema is None, "Read PARQUET does not support user schema"
      assert len(paths) > 0, "Read PARQUET expects at least one path"

+     apply_metadata_exclusion_pattern(snowpark_options)
+
      reader = add_filename_metadata_to_reader(
          session.read.options(snowpark_options), raw_options
      )
@@ -26,6 +26,10 @@ def get_file_paths_from_stage(
  ) -> typing.List[str]:
      files_paths = []
      for listed_path_row in session.sql(f"LIST {path}").collect():
+         # Skip _SUCCESS marker files
+         if listed_path_row[0].endswith("_SUCCESS"):
+             continue
+
          listed_path = listed_path_row[0].split("/")
          if listed_path_row[0].startswith("s3://") or listed_path_row[0].startswith(
              "s3a://"
@@ -126,6 +126,7 @@ CSV_READ_SUPPORTED_OPTIONS = lowercase_set(
          "compression",
          # "escapeQuotes",
          # "quoteAll",
+         "rowsToInferSchema", # Snowflake specific option, number of rows to infer schema
      }
  )

@@ -201,6 +202,15 @@ def csv_convert_to_snowpark_args(snowpark_config: dict[str, Any]) -> dict[str, A
      if snowpark_config["escape"] and snowpark_config["escape"] == "\\":
          snowpark_config["escape"] = "\\\\"

+     # Snowflake specific option, number of rows to infer schema for CSV files
+     if "rowstoinferschema" in snowpark_config:
+         rows_to_infer_schema = snowpark_config["rowstoinferschema"]
+         del snowpark_config["rowstoinferschema"]
+         snowpark_config["INFER_SCHEMA_OPTIONS"] = {
+             "MAX_RECORDS_PER_FILE": int(rows_to_infer_schema),
+             "USE_RELAXED_TYPES": True,
+         }
+
      # Rename the keys to match the Snowpark configuration.
      for spark_arg, snowpark_arg in renamed_args.items():
          if spark_arg not in snowpark_config:
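
Together with the reader_config change above, the Snowflake-specific option can be passed from the Spark reader and is translated into Snowflake's INFER_SCHEMA_OPTIONS. A hypothetical usage (stage path and session name assumed):

    df = (
        spark.read.option("inferSchema", True)
        .option("rowsToInferSchema", 500)  # sample up to 500 rows per file for schema inference
        .csv("@my_stage/data/")
    )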