snowpark-connect 0.31.0__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (111)
  1. snowflake/snowpark_connect/__init__.py +1 -0
  2. snowflake/snowpark_connect/column_name_handler.py +143 -105
  3. snowflake/snowpark_connect/column_qualifier.py +43 -0
  4. snowflake/snowpark_connect/dataframe_container.py +3 -2
  5. snowflake/snowpark_connect/execute_plan/map_execution_command.py +4 -2
  6. snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
  7. snowflake/snowpark_connect/expression/map_expression.py +5 -4
  8. snowflake/snowpark_connect/expression/map_extension.py +12 -6
  9. snowflake/snowpark_connect/expression/map_sql_expression.py +50 -7
  10. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +62 -25
  11. snowflake/snowpark_connect/expression/map_unresolved_function.py +924 -127
  12. snowflake/snowpark_connect/expression/map_unresolved_star.py +9 -7
  13. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  14. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  15. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  16. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  17. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +4 -1
  18. snowflake/snowpark_connect/relation/map_aggregate.py +6 -5
  19. snowflake/snowpark_connect/relation/map_column_ops.py +9 -3
  20. snowflake/snowpark_connect/relation/map_extension.py +10 -9
  21. snowflake/snowpark_connect/relation/map_join.py +219 -144
  22. snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
  23. snowflake/snowpark_connect/relation/map_sql.py +134 -16
  24. snowflake/snowpark_connect/relation/map_subquery_alias.py +4 -1
  25. snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
  26. snowflake/snowpark_connect/relation/read/map_read_table.py +6 -3
  27. snowflake/snowpark_connect/relation/utils.py +46 -0
  28. snowflake/snowpark_connect/relation/write/map_write.py +215 -289
  29. snowflake/snowpark_connect/resources_initializer.py +25 -13
  30. snowflake/snowpark_connect/server.py +10 -26
  31. snowflake/snowpark_connect/type_mapping.py +38 -3
  32. snowflake/snowpark_connect/typed_column.py +8 -6
  33. snowflake/snowpark_connect/utils/sequence.py +21 -0
  34. snowflake/snowpark_connect/utils/session.py +27 -4
  35. snowflake/snowpark_connect/version.py +1 -1
  36. snowflake/snowpark_decoder/dp_session.py +1 -1
  37. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +7 -2
  38. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +46 -105
  39. snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
  40. snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
  41. snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
  42. snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
  43. snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
  44. snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
  45. snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
  46. snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
  47. snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
  48. snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
  49. snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
  50. snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
  51. snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
  52. snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
  53. snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
  54. snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
  55. snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
  56. snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
  57. snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
  58. snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
  59. snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
  60. snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
  61. snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
  62. snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
  63. snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
  64. snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
  65. snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
  66. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  67. snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
  68. snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
  69. snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
  70. snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
  71. snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
  72. snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
  73. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  74. snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
  75. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  76. snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
  77. snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
  78. snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
  79. snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
  80. snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
  81. snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
  82. snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
  83. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
  84. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
  85. snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
  86. snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
  87. snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
  88. snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
  89. snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
  90. snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
  91. snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
  92. snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
  93. snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
  94. snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
  95. snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
  96. snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
  97. snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
  98. snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
  99. snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
  100. snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
  101. snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
  102. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +0 -4
  103. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +0 -4
  104. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
  105. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
  106. {snowpark_connect-0.31.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
  107. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
  108. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
  109. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
  110. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
  111. {snowpark_connect-0.31.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0
@@ -61,6 +61,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+
+# Import from utils for consistency
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import map_snowpark_to_pyspark_types
 from snowflake.snowpark_connect.utils.context import (
     _accessing_temp_object,
@@ -588,25 +591,28 @@ def map_sql_to_pandas_df(
                 f"CREATE TABLE {if_not_exists}{name} LIKE {source}"
             ).collect()
         case "CreateTempViewUsing":
+            parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")
+
+            spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name
+
+            num_columns = len(list(parsed_sql.find_all(sqlglot.exp.ColumnDef)))
+            null_list = (
+                ", ".join(["NULL"] * num_columns) if num_columns > 0 else "*"
+            )
             empty_select = (
-                " AS SELECT * WHERE 1 = 0"
+                f" AS SELECT {null_list} WHERE 1 = 0"
                 if logical_plan.options().isEmpty()
                 and logical_plan.children().isEmpty()
                 else ""
             )
-            parsed_sql = (
-                sqlglot.parse_one(sql_string, dialect="spark")
-                .transform(_normalize_identifiers)
+
+            transformed_sql = (
+                parsed_sql.transform(_normalize_identifiers)
                 .transform(_remove_column_data_type)
                 .transform(_remove_file_format_property)
             )
-            snowflake_sql = parsed_sql.sql(dialect="snowflake")
+            snowflake_sql = transformed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
-            spark_view_name = next(
-                sqlglot.parse_one(sql_string, dialect="spark").find_all(
-                    sqlglot.exp.Table
-                )
-            ).name
             snowflake_view_name = spark_to_sf_single_id_with_unquoting(
                 spark_view_name
             )
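The CreateTempViewUsing change parses the Spark statement once with sqlglot, derives the view name and column count from the parse tree, and transpiles the transformed tree to Snowflake SQL. As a rough illustration of that sqlglot pattern (a standalone sketch with a made-up CREATE TABLE statement, not code from the package):

import sqlglot

# Hypothetical input; a CREATE TABLE is used so the snippet parses on its own.
sql_string = "CREATE TABLE t (id INT, name STRING)"

parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")

# First table identifier in the statement (the view name in the real code path).
object_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name          # "t"

# One NULL placeholder per declared column, mirroring the new empty_select logic.
num_columns = len(list(parsed_sql.find_all(sqlglot.exp.ColumnDef)))      # 2
null_list = ", ".join(["NULL"] * num_columns) if num_columns > 0 else "*"

# Transpile the (optionally transformed) tree into Snowflake's dialect.
print(parsed_sql.sql(dialect="snowflake"))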
@@ -877,16 +883,71 @@ def map_sql_to_pandas_df(
             overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
             cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
 
-            try:
-                target_table = session.table(name)
-                target_schema = target_table.schema
+            # Extract partition spec if any
+            partition_spec = logical_plan.partitionSpec()
+            partition_map = as_java_map(partition_spec)
+
+            partition_columns = {}
+            for entry in partition_map.entrySet():
+                col_name = str(entry.getKey())
+                value_option = entry.getValue()
+                if value_option.isDefined():
+                    partition_columns[col_name] = value_option.get()
+
+            # Add partition columns to the dataframe
+            if partition_columns:
+                """
+                Spark sends them in the partition spec and the values won't be present in the values array.
+                As snowflake does not support static partitions in INSERT INTO statements,
+                we need to add the partition columns to the dataframe as literal columns.
+
+                ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+                Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+                Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
 
-                # Modify df with NaN NULL conversion for DECIMAL columns
+                We need to add the partition columns to the dataframe as literal columns.
+
+                ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+                    df = df.withColumn('hr', snowpark_fn.lit(10))
+
+                Then the final query will be:
+                INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+                """
+                for partition_col, partition_value in partition_columns.items():
+                    df = df.withColumn(
+                        partition_col, snowpark_fn.lit(partition_value)
+                    )
+
+            target_table = session.table(name)
+            target_schema = target_table.schema
+
+            expected_number_of_columns = (
+                len(user_columns) if user_columns else len(target_schema.fields)
+            )
+            if expected_number_of_columns != len(df.schema.fields):
+                reason = (
+                    "too many data columns"
+                    if len(df.schema.fields) > expected_number_of_columns
+                    else "not enough data columns"
+                )
+                exception = AnalysisException(
+                    f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+                    f'Table columns: {", ".join(target_schema.names)}.\n'
+                    f'Data columns: {", ".join(df.schema.names)}.'
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                raise exception
+
+            try:
+                # Modify df with type conversions and struct field name mapping
                 modified_columns = []
                 for source_field, target_field in zip(
                     df.schema.fields, target_schema.fields
                 ):
                     col_name = source_field.name
+
+                    # Handle different type conversions
                     if isinstance(
                         target_field.datatype, snowpark.types.DecimalType
                     ) and isinstance(
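The docstring above describes folding a static PARTITION spec into the inserted rows. A minimal standalone sketch of that rewrite in plain Python (hypothetical values, not the package code):

# Rows as Spark sends them, without the partition columns.
rows = [("k1", 100), ("k2", 200), ("k3", 300)]

# Static partition spec extracted from the INSERT statement.
partition_columns = {"ds": "2021-01-01", "hr": 10}

# Append each partition value as a constant column, mirroring
# df.withColumn(partition_col, snowpark_fn.lit(partition_value)) above.
rewritten = [row + tuple(partition_columns.values()) for row in rows]

assert rewritten[0] == ("k1", 100, "2021-01-01", 10)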
@@ -904,12 +965,25 @@ def map_sql_to_pandas_df(
                             .alias(col_name)
                         )
                         modified_columns.append(modified_col)
+                    elif (
+                        isinstance(target_field.datatype, snowpark.types.StructType)
+                        and source_field.datatype != target_field.datatype
+                    ):
+                        # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                        # This fixes INSERT INTO table with struct literals like (2, 3)
+                        modified_col = (
+                            snowpark_fn.col(col_name)
+                            .cast(target_field.datatype, rename_fields=True)
+                            .alias(col_name)
+                        )
+                        modified_columns.append(modified_col)
                     else:
                         modified_columns.append(snowpark_fn.col(col_name))
 
                 df = df.select(modified_columns)
             except Exception:
                 pass
+
             queries = df.queries["queries"]
             final_query = queries[0]
             session.sql(
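The struct-cast branch targets INSERTs whose struct literals carry auto-generated field names. A hedged client-side illustration, assuming an active Spark Connect session `spark` and a hypothetical table whose second column is a struct with fields i1 and i2 (in Spark SQL, struct(2, 3) yields fields named col1 and col2):

# The literal struct(2, 3) arrives as (col1, col2); the cast with
# rename_fields=True in the hunk above realigns it to the target's i1, i2.
spark.sql("INSERT INTO items VALUES (1, struct(2, 3))")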
@@ -1343,6 +1417,33 @@ def map_sql_to_pandas_df(
             )
             SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
 
+            return pandas.DataFrame({"": [""]}), ""
+        case "RepairTable":
+            # No-Op: Snowflake doesn't have explicit partitions to repair.
+            table_relation = logical_plan.child()
+            db_and_table_name = as_java_list(table_relation.multipartIdentifier())
+            multi_part_len = len(db_and_table_name)
+
+            if multi_part_len == 1:
+                table_name = db_and_table_name[0]
+                db_name = None
+                full_table_name = table_name
+            else:
+                db_name = db_and_table_name[0]
+                table_name = db_and_table_name[1]
+                full_table_name = db_name + "." + table_name
+
+            df = SNOWFLAKE_CATALOG.tableExists(table_name, db_name)
+
+            table_exist = df.iloc[0, 0]
+
+            if not table_exist:
+                exception = AnalysisException(
+                    f"[TABLE_OR_VIEW_NOT_FOUND] Table not found `{full_table_name}`."
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_OPERATION)
+                raise exception
+
             return pandas.DataFrame({"": [""]}), ""
         case _:
             execute_logical_plan(logical_plan)
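The new RepairTable branch makes table repair a validated no-op. A brief usage illustration from the client side (assuming an active Spark Connect session `spark`; the table name is hypothetical):

# Succeeds as a no-op when the table exists; raises AnalysisException
# with [TABLE_OR_VIEW_NOT_FOUND] when it does not.
spark.sql("MSCK REPAIR TABLE sales_db.events")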
@@ -1483,7 +1584,12 @@ def map_sql(
     snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)
 
     if not snowpark_connect_sql_passthrough:
-        logical_plan = sql_parser().parseQuery(sql_stmt)
+        # Changed from parseQuery to parsePlan as Spark parseQuery() call generating wrong logical plan for
+        # query like this: SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo'
+        # As such other place in this file we use parsePlan.
+        # Main difference between parsePlan() and parseQuery() is, parsePlan() can be called for any SQL statement, while
+        # parseQuery() can only be called for query statements.
+        logical_plan = sql_parser().parsePlan(sql_stmt)
 
     parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
     set_sql_args(rel.sql.args, parsed_pos_args)
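The comment above cites the statement that parseQuery() mishandled. The shape of that query, reproduced from a client session (assuming an active session `spark`):

# A UNION of a decimal literal and a string; parsePlan() accepts any SQL
# statement, while parseQuery() only handles query statements.
spark.sql("SELECT cast('3.4' as decimal(38, 18)) UNION SELECT 'foo'").show()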
@@ -1566,7 +1672,19 @@ def map_logical_plan_relation(
         attr_parts = as_java_list(expr.nameParts())
         if len(attr_parts) == 1:
             attr_name = str(attr_parts[0])
-            return alias_map.get(attr_name, expr)
+            if attr_name in alias_map:
+                # Check if the alias references an aggregate function
+                # If so, don't substitute because you can't GROUP BY an aggregate
+                aliased_expr = alias_map[attr_name]
+                aliased_expr_class = str(
+                    aliased_expr.getClass().getSimpleName()
+                )
+                if aliased_expr_class == "UnresolvedFunction":
+                    func_name = str(aliased_expr.nameParts().head())
+                    if is_aggregate_function(func_name):
+                        return expr
+                return aliased_expr
+            return expr
 
     return expr
 
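The guard matters when a SELECT alias shadows a real column and wraps an aggregate; substituting the alias back into GROUP BY would produce an invalid grouping expression. A hypothetical query of that shape (assuming an active session `spark` and a payments table):

# "amount" is both a source column and an alias for sum(amount).
# Without the guard, GROUP BY amount would be rewritten to
# GROUP BY sum(amount); with it, the underlying column is kept.
spark.sql("""
    SELECT city, sum(amount) AS amount
    FROM payments
    GROUP BY city, amount
""").show()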
@@ -4,6 +4,7 @@
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
 
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 
@@ -18,7 +19,9 @@ def map_alias(
     # we set reuse_parsed_plan=False because we need new expr_id for the attributes (output columns) in aliased snowpark dataframe
     # reuse_parsed_plan will lead to ambiguous column name for operations like joining two dataframes that are aliased from the same dataframe
     input_container = map_relation(rel.subquery_alias.input, reuse_parsed_plan=False)
-    qualifiers = [[alias]] * len(input_container.column_map.columns)
+    qualifiers = [
+        {ColumnQualifier((alias,))} for _ in input_container.column_map.columns
+    ]
 
     return DataFrameContainer.create_with_column_mapping(
         dataframe=input_container.dataframe,
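The qualifiers now come as one fresh set of ColumnQualifier objects per column instead of n references to a single shared list. The ColumnQualifier class itself (column_qualifier.py, +43 lines) is not shown in this diff; a minimal hypothetical stand-in consistent with the usage above, hashable and tuple-based, could look like:

from dataclasses import dataclass


@dataclass(frozen=True)
class ColumnQualifier:
    # Hypothetical sketch only; the real definition lives in
    # snowflake/snowpark_connect/column_qualifier.py.
    parts: tuple


# One independent set per column, so mutating one column's qualifiers
# can no longer leak into the others.
columns = ["id", "name", "amount"]
alias = "t1"
qualifiers = [{ColumnQualifier((alias,))} for _ in columns]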
@@ -117,6 +117,10 @@ def map_read_json(
         if unquote_if_quoted(sf.name) in columns_with_valid_contents
     ]
 
+    new_schema, fields_changed = validate_and_update_schema(schema)
+    if fields_changed:
+        schema = new_schema
+
     df = construct_dataframe_by_schema(
         schema, df.to_local_iterator(), session, snowpark_options, batch_size
     )
@@ -134,6 +138,84 @@ def map_read_json(
     )
 
 
+def should_drop_field(field: StructField) -> bool:
+    if isinstance(field.datatype, StructType):
+        # "a" : {} => drop the field
+        if len(field.datatype.fields) == 0:
+            return True
+    elif (
+        isinstance(field.datatype, ArrayType)
+        and field.datatype.element_type is not None
+        and isinstance(field.datatype.element_type, StructType)
+    ):
+        if len(field.datatype.element_type.fields) == 0:
+            # "a" : [{}] => drop the field
+            return True
+    return False
+
+
+# Validate the schema to ensure it is valid for Snowflake
+# Handles these cases:
+# 1. Drops StructField([])
+# 2. Drops ArrayType(StructType([]))
+# 3. ArrayType() -> ArrayType(StringType())
+def validate_and_update_schema(schema: StructType | None) -> (StructType | None, bool):
+    if not isinstance(schema, StructType):
+        return schema, False
+    new_fields = []
+    fields_changed = False
+    for sf in schema.fields:
+        if should_drop_field(sf):
+            fields_changed = True
+            continue
+        if isinstance(sf.datatype, StructType):
+            # If the schema is a struct, validate the child schema
+            if len(sf.datatype.fields) == 0:
+                # No fields in the struct, drop the field
+                fields_changed = True
+                continue
+            child_field = StructField(sf.name, sf.datatype, sf.nullable)
+            # Recursively validate the child schema
+            child_field.datatype, child_field_changes = validate_and_update_schema(
+                sf.datatype
+            )
+            if should_drop_field(child_field):
+                fields_changed = True
+                continue
+            new_fields.append(child_field)
+            fields_changed = fields_changed or child_field_changes
+        elif isinstance(sf.datatype, ArrayType):
+            # If the schema is an array, validate the element schema
+            if sf.datatype.element_type is not None and isinstance(
+                sf.datatype.element_type, StructType
+            ):
+                # If the element schema is a struct, validate the element schema
+                if len(sf.datatype.element_type.fields) == 0:
+                    # No fields in the struct, drop the field
+                    fields_changed = True
+                    continue
+                else:
+                    # Recursively validate the element schema
+                    element_schema, element_field_changes = validate_and_update_schema(
+                        sf.datatype.element_type
+                    )
+                    if element_field_changes:
+                        sf.datatype.element_type = element_schema
+                        fields_changed = True
+                    if should_drop_field(sf):
+                        fields_changed = True
+                        continue
+            elif sf.datatype.element_type is None:
+                fields_changed = True
+                sf.datatype.element_type = StringType()
+            new_fields.append(sf)
+        else:
+            new_fields.append(sf)
+    if fields_changed:
+        schema.fields = new_fields
+    return schema, fields_changed
+
+
 def merge_json_schema(
     content: typing.Any,
     schema: StructType | None,
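To see what the new validation does to an inferred JSON schema, here is a hedged usage sketch (assuming the Snowpark type classes this module works with are importable as below; expected results follow the rules listed above validate_and_update_schema):

from snowflake.snowpark.types import ArrayType, StringType, StructField, StructType

schema = StructType([
    StructField("a", StructType([])),             # "a": {}    -> dropped
    StructField("b", ArrayType(StructType([]))),  # "b": [{}]  -> dropped
    StructField("c", ArrayType()),                # untyped array -> ArrayType(StringType())
    StructField("d", StringType()),               # kept unchanged
])

new_schema, changed = validate_and_update_schema(schema)
# changed is True; the surviving fields are "c" (now ArrayType(StringType())) and "d".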
@@ -378,8 +460,11 @@ def construct_row_by_schema(
         inner_schema = schema.element_type
         if isinstance(content, str):
            content = json.loads(content)
-        for ele in content:
-            result.append(construct_row_by_schema(ele, inner_schema, snowpark_options))
+        if inner_schema is not None:
+            for ele in content:
+                result.append(
+                    construct_row_by_schema(ele, inner_schema, snowpark_options)
+                )
         return result
     elif isinstance(schema, DateType):
         return cast_to_match_snowpark_type(
@@ -16,6 +16,7 @@ from snowflake.snowpark_connect.column_name_handler import (
     ColumnNameMap,
     make_column_names_snowpark_compatible,
 )
+from snowflake.snowpark_connect.column_qualifier import ColumnQualifier
 from snowflake.snowpark_connect.config import auto_uppercase_non_column_identifiers
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
@@ -58,7 +59,7 @@ def post_process_df(
         spark_column_names=true_names,
         snowpark_column_names=snowpark_column_names,
         snowpark_column_types=[f.datatype for f in df.schema.fields],
-        column_qualifiers=[name_parts] * len(true_names)
+        column_qualifiers=[{ColumnQualifier(tuple(name_parts))} for _ in true_names]
         if source_table_name
         else None,
     )
@@ -94,8 +95,10 @@ def _get_temporary_view(
         spark_column_names=temp_view.column_map.get_spark_columns(),
         snowpark_column_names=snowpark_column_names,
         column_metadata=temp_view.column_map.column_metadata,
-        column_qualifiers=[split_fully_qualified_spark_name(table_name)]
-        * len(temp_view.column_map.get_spark_columns()),
+        column_qualifiers=[
+            {ColumnQualifier(tuple(split_fully_qualified_spark_name(table_name)))}
+            for _ in range(len(temp_view.column_map.get_spark_columns()))
+        ],
         parent_column_name_map=temp_view.column_map.get_parent_column_name_map(),
     )
 
@@ -284,3 +284,49 @@ def snowpark_functions_col(name: str, column_map: ColumnNameMap) -> snowpark.Col
     """
     is_qualified_name = name not in column_map.get_snowpark_columns()
     return snowpark_fn.col(name, _is_qualified_name=is_qualified_name)
+
+
+def is_aggregate_function(func_name: str) -> bool:
+    """
+    Check if a function name is an aggregate function.
+
+    Uses a hybrid approach:
+    1. First checks PySpark's docstring convention (docstrings starting with "Aggregate function:")
+    2. Falls back to a hardcoded list for functions with missing/incorrect docstrings
+
+    This ensures comprehensive coverage while automatically supporting new PySpark aggregate functions.
+
+    Args:
+        func_name: The function name to check (case-insensitive)
+
+    Returns:
+        True if the function is an aggregate function, False otherwise
+    """
+    try:
+        import pyspark.sql.functions as pyspark_functions
+
+        # TODO:
+        """
+        Check we can leverage scala classes to determine agg functions:
+        https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L207
+        """
+
+        # Try PySpark docstring approach first (covers most aggregate functions)
+        pyspark_func = getattr(pyspark_functions, func_name.lower(), None)
+        if pyspark_func and pyspark_func.__doc__:
+            if pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
+                return True
+
+        # Fallback list for aggregate functions with missing/incorrect docstrings
+        # These are known aggregate functions that don't have proper docstring markers
+        fallback_aggregates = {
+            "percentile_cont",
+            "percentile_disc",
+            "any_value",
+            "grouping",
+            "grouping_id",
+        }
+        return func_name.lower() in fallback_aggregates
+
+    except Exception:
+        return False
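The docstring check leans on PySpark's convention that aggregate functions begin their docstrings with "Aggregate function:". A quick way to see that convention, plus the helper's expected answers (requires pyspark installed; behaviour per the code above):

import pyspark.sql.functions as F

# PySpark marks aggregates in the first docstring line, e.g. for sum():
#   "Aggregate function: returns the sum of all values in the expression."
print(F.sum.__doc__.lstrip().splitlines()[0])

# With the helper above:
#   is_aggregate_function("sum")         -> True   (docstring convention)
#   is_aggregate_function("upper")       -> False  (not an aggregate)
#   is_aggregate_function("grouping_id") -> True   (fallback list)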