snowpark-connect 0.32.0-py3-none-any.whl → 0.33.0-py3-none-any.whl
This diff represents the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: this version of snowpark-connect has been flagged by the registry's scanner.
- snowflake/snowpark_connect/column_name_handler.py +92 -27
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/map_sql_expression.py +12 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +58 -21
- snowflake/snowpark_connect/expression/map_unresolved_function.py +62 -27
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +2 -4
- snowflake/snowpark_connect/relation/map_column_ops.py +5 -0
- snowflake/snowpark_connect/relation/map_join.py +218 -146
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +102 -16
- snowflake/snowpark_connect/relation/read/map_read_json.py +87 -2
- snowflake/snowpark_connect/relation/utils.py +46 -0
- snowflake/snowpark_connect/relation/write/map_write.py +186 -275
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +8 -1
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/METADATA +3 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/RECORD +35 -93
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-0.33.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-0.33.0.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/relation/map_sql.py

@@ -61,6 +61,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+
+# Import from utils for consistency
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import map_snowpark_to_pyspark_types
 from snowflake.snowpark_connect.utils.context import (
     _accessing_temp_object,

@@ -588,25 +591,28 @@ def map_sql_to_pandas_df(
                 f"CREATE TABLE {if_not_exists}{name} LIKE {source}"
             ).collect()
         case "CreateTempViewUsing":
+            parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")
+
+            spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name
+
+            num_columns = len(list(parsed_sql.find_all(sqlglot.exp.ColumnDef)))
+            null_list = (
+                ", ".join(["NULL"] * num_columns) if num_columns > 0 else "*"
+            )
             empty_select = (
-                " AS SELECT
+                f" AS SELECT {null_list} WHERE 1 = 0"
                 if logical_plan.options().isEmpty()
                 and logical_plan.children().isEmpty()
                 else ""
             )
-
-
-                .transform(_normalize_identifiers)
+
+            transformed_sql = (
+                parsed_sql.transform(_normalize_identifiers)
                 .transform(_remove_column_data_type)
                 .transform(_remove_file_format_property)
             )
-            snowflake_sql =
+            snowflake_sql = transformed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
-            spark_view_name = next(
-                sqlglot.parse_one(sql_string, dialect="spark").find_all(
-                    sqlglot.exp.Table
-                )
-            ).name
             snowflake_view_name = spark_to_sf_single_id_with_unquoting(
                 spark_view_name
             )

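The rewritten CreateTempViewUsing branch parses the Spark SQL once, reuses the tree for the object name and column count, and transpiles it to Snowflake syntax, appending an empty SELECT so the object is created with the declared shape but no rows. A minimal sketch of that sqlglot round-trip; the package's private _normalize_identifiers/_remove_* transforms are omitted, and a plain CREATE TABLE stands in for the view DDL:

import sqlglot

# Hypothetical input; the real branch receives the CREATE TEMPORARY VIEW ... USING statement.
sql_string = "CREATE TABLE v (a INT, b STRING)"
parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")

# Reuse the parsed tree for the object name and the column count.
view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name         # 'v'
num_columns = len(list(parsed_sql.find_all(sqlglot.exp.ColumnDef)))   # 2
null_list = ", ".join(["NULL"] * num_columns) if num_columns > 0 else "*"

# Transpile to Snowflake syntax and append an empty SELECT so the object
# is created with the right columns but zero rows.
snowflake_sql = parsed_sql.sql(dialect="snowflake")
print(f"{snowflake_sql} AS SELECT {null_list} WHERE 1 = 0")
# e.g. CREATE TABLE v (a INT, b VARCHAR) AS SELECT NULL, NULL WHERE 1 = 0
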
@@ -877,16 +883,71 @@ def map_sql_to_pandas_df(
             overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
             cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
 
-
-
-
+            # Extract partition spec if any
+            partition_spec = logical_plan.partitionSpec()
+            partition_map = as_java_map(partition_spec)
+
+            partition_columns = {}
+            for entry in partition_map.entrySet():
+                col_name = str(entry.getKey())
+                value_option = entry.getValue()
+                if value_option.isDefined():
+                    partition_columns[col_name] = value_option.get()
+
+            # Add partition columns to the dataframe
+            if partition_columns:
+                """
+                Spark sends them in the partition spec and the values won't be present in the values array.
+                As snowflake does not support static partitions in INSERT INTO statements,
+                we need to add the partition columns to the dataframe as literal columns.
+
+                ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+                Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+                Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+                We need to add the partition columns to the dataframe as literal columns.
+
+                ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+                    df = df.withColumn('hr', snowpark_fn.lit(10))
+
+                Then the final query will be:
+                INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+                """
+                for partition_col, partition_value in partition_columns.items():
+                    df = df.withColumn(
+                        partition_col, snowpark_fn.lit(partition_value)
+                    )
+
+            target_table = session.table(name)
+            target_schema = target_table.schema
 
-
+            expected_number_of_columns = (
+                len(user_columns) if user_columns else len(target_schema.fields)
+            )
+            if expected_number_of_columns != len(df.schema.fields):
+                reason = (
+                    "too many data columns"
+                    if len(df.schema.fields) > expected_number_of_columns
+                    else "not enough data columns"
+                )
+                exception = AnalysisException(
+                    f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+                    f'Table columns: {", ".join(target_schema.names)}.\n'
+                    f'Data columns: {", ".join(df.schema.names)}.'
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+                raise exception
+
+            try:
+                # Modify df with type conversions and struct field name mapping
                 modified_columns = []
                 for source_field, target_field in zip(
                     df.schema.fields, target_schema.fields
                 ):
                     col_name = source_field.name
+
+                    # Handle different type conversions
                     if isinstance(
                         target_field.datatype, snowpark.types.DecimalType
                     ) and isinstance(

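The net effect is that a static PARTITION clause is folded into the data before the INSERT is generated. A small sketch of the literal-column step using the Snowpark DataFrame API; it assumes an active Snowpark session, and the partition_columns dict stands in for the values parsed from the plan's partition spec:

from snowflake.snowpark import functions as snowpark_fn

# Stand-in for the values extracted from the plan's partition spec.
partition_columns = {"ds": "2021-01-01", "hr": 10}

df = session.create_dataframe([("k1", 100), ("k2", 200)], schema=["k", "v"])
for partition_col, partition_value in partition_columns.items():
    # Append each static partition value to every row as a literal column.
    df = df.with_column(partition_col, snowpark_fn.lit(partition_value))
# df now carries k, v, ds, hr, so a plain INSERT with no PARTITION clause works.
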
@@ -904,12 +965,25 @@ def map_sql_to_pandas_df(
                             .alias(col_name)
                         )
                         modified_columns.append(modified_col)
+                    elif (
+                        isinstance(target_field.datatype, snowpark.types.StructType)
+                        and source_field.datatype != target_field.datatype
+                    ):
+                        # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                        # This fixes INSERT INTO table with struct literals like (2, 3)
+                        modified_col = (
+                            snowpark_fn.col(col_name)
+                            .cast(target_field.datatype, rename_fields=True)
+                            .alias(col_name)
+                        )
+                        modified_columns.append(modified_col)
                     else:
                         modified_columns.append(snowpark_fn.col(col_name))
 
                 df = df.select(modified_columns)
             except Exception:
                 pass
+
             queries = df.queries["queries"]
             final_query = queries[0]
             session.sql(

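The key piece here is cast(..., rename_fields=True): the struct keeps its field positions but adopts the target's field names, which is what lets a positional struct literal like (2, 3) land in a table whose struct fields are named differently. A sketch, assuming a Snowpark df with a struct column "s" whose fields are col1/col2:

import snowflake.snowpark as snowpark
from snowflake.snowpark import functions as snowpark_fn

target_type = snowpark.types.StructType(
    [
        snowpark.types.StructField("i1", snowpark.types.IntegerType()),
        snowpark.types.StructField("i2", snowpark.types.IntegerType()),
    ]
)
# Positions are preserved; only the field names change (col1 -> i1, col2 -> i2).
df = df.select(
    snowpark_fn.col("s").cast(target_type, rename_fields=True).alias("s")
)
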
@@ -1345,7 +1419,7 @@ def map_sql_to_pandas_df(
 
             return pandas.DataFrame({"": [""]}), ""
         case "RepairTable":
-            # No-Op
+            # No-Op: Snowflake doesn't have explicit partitions to repair.
             table_relation = logical_plan.child()
             db_and_table_name = as_java_list(table_relation.multipartIdentifier())
             multi_part_len = len(db_and_table_name)

@@ -1598,7 +1672,19 @@ def map_logical_plan_relation(
         attr_parts = as_java_list(expr.nameParts())
         if len(attr_parts) == 1:
             attr_name = str(attr_parts[0])
-
+            if attr_name in alias_map:
+                # Check if the alias references an aggregate function
+                # If so, don't substitute because you can't GROUP BY an aggregate
+                aliased_expr = alias_map[attr_name]
+                aliased_expr_class = str(
+                    aliased_expr.getClass().getSimpleName()
+                )
+                if aliased_expr_class == "UnresolvedFunction":
+                    func_name = str(aliased_expr.nameParts().head())
+                    if is_aggregate_function(func_name):
+                        return expr
+                return aliased_expr
+            return expr
 
         return expr
 
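The guard exists because Spark allows GROUP BY to reference a SELECT alias, but only when the alias does not name an aggregate. A hypothetical, Catalyst-free rendering of the same decision, using plain dicts in place of the JVM expression objects:

def resolve_group_by_alias(attr_name, alias_map, is_aggregate_function):
    # alias_map: alias -> {"class": ..., "name": ...}, a stand-in for Catalyst nodes
    if attr_name in alias_map:
        aliased = alias_map[attr_name]
        if aliased["class"] == "UnresolvedFunction" and is_aggregate_function(
            aliased["name"]
        ):
            # GROUP BY c where c aliases count(*): invalid, keep the bare attribute
            return attr_name
        # GROUP BY n where n aliases a plain expression: substitute it
        return aliased
    return attr_name

# SELECT name AS n, count(*) AS c FROM t GROUP BY n  -> n is rewritten to name
# SELECT name, count(*) AS c FROM t GROUP BY c       -> c is left untouched
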
snowflake/snowpark_connect/relation/read/map_read_json.py

@@ -117,6 +117,10 @@ def map_read_json(
         if unquote_if_quoted(sf.name) in columns_with_valid_contents
     ]
 
+    new_schema, fields_changed = validate_and_update_schema(schema)
+    if fields_changed:
+        schema = new_schema
+
     df = construct_dataframe_by_schema(
         schema, df.to_local_iterator(), session, snowpark_options, batch_size
     )

@@ -134,6 +138,84 @@ def map_read_json(
 )
 
 
+def should_drop_field(field: StructField) -> bool:
+    if isinstance(field.datatype, StructType):
+        # "a" : {} => drop the field
+        if len(field.datatype.fields) == 0:
+            return True
+    elif (
+        isinstance(field.datatype, ArrayType)
+        and field.datatype.element_type is not None
+        and isinstance(field.datatype.element_type, StructType)
+    ):
+        if len(field.datatype.element_type.fields) == 0:
+            # "a" : [{}] => drop the field
+            return True
+    return False
+
+
+# Validate the schema to ensure it is valid for Snowflake
+# Handles these cases:
+# 1. Drops StructField([])
+# 2. Drops ArrayType(StructType([]))
+# 3. ArrayType() -> ArrayType(StringType())
+def validate_and_update_schema(schema: StructType | None) -> (StructType | None, bool):
+    if not isinstance(schema, StructType):
+        return schema, False
+    new_fields = []
+    fields_changed = False
+    for sf in schema.fields:
+        if should_drop_field(sf):
+            fields_changed = True
+            continue
+        if isinstance(sf.datatype, StructType):
+            # If the schema is a struct, validate the child schema
+            if len(sf.datatype.fields) == 0:
+                # No fields in the struct, drop the field
+                fields_changed = True
+                continue
+            child_field = StructField(sf.name, sf.datatype, sf.nullable)
+            # Recursively validate the child schema
+            child_field.datatype, child_field_changes = validate_and_update_schema(
+                sf.datatype
+            )
+            if should_drop_field(child_field):
+                fields_changed = True
+                continue
+            new_fields.append(child_field)
+            fields_changed = fields_changed or child_field_changes
+        elif isinstance(sf.datatype, ArrayType):
+            # If the schema is an array, validate the element schema
+            if sf.datatype.element_type is not None and isinstance(
+                sf.datatype.element_type, StructType
+            ):
+                # If the element schema is a struct, validate the element schema
+                if len(sf.datatype.element_type.fields) == 0:
+                    # No fields in the struct, drop the field
+                    fields_changed = True
+                    continue
+                else:
+                    # Recursively validate the element schema
+                    element_schema, element_field_changes = validate_and_update_schema(
+                        sf.datatype.element_type
+                    )
+                    if element_field_changes:
+                        sf.datatype.element_type = element_schema
+                        fields_changed = True
+                    if should_drop_field(sf):
+                        fields_changed = True
+                        continue
+            elif sf.datatype.element_type is None:
+                fields_changed = True
+                sf.datatype.element_type = StringType()
+            new_fields.append(sf)
+        else:
+            new_fields.append(sf)
+    if fields_changed:
+        schema.fields = new_fields
+    return schema, fields_changed
+
+
 def merge_json_schema(
     content: typing.Any,
     schema: StructType | None,

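A usage sketch of the new validation with Snowpark types; the empty structs are what schema inference produces for JSON like {"a": {}} or {"b": [{}]}, which Snowflake cannot represent as columns:

from snowflake.snowpark.types import ArrayType, StringType, StructField, StructType

schema = StructType(
    [
        StructField("a", StructType([])),             # "a": {}   -> dropped
        StructField("b", ArrayType(StructType([]))),  # "b": [{}] -> dropped
        StructField("c", ArrayType()),                # untyped   -> ArrayType(StringType())
        StructField("d", StringType()),               # kept as-is
    ]
)
schema, fields_changed = validate_and_update_schema(schema)
# fields_changed is True; schema now contains only "c" (as ARRAY of STRING) and "d".
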
@@ -378,8 +460,11 @@ def construct_row_by_schema(
         inner_schema = schema.element_type
         if isinstance(content, str):
             content = json.loads(content)
-
-
+        if inner_schema is not None:
+            for ele in content:
+                result.append(
+                    construct_row_by_schema(ele, inner_schema, snowpark_options)
+                )
         return result
     elif isinstance(schema, DateType):
         return cast_to_match_snowpark_type(

snowflake/snowpark_connect/relation/utils.py

@@ -284,3 +284,49 @@ def snowpark_functions_col(name: str, column_map: ColumnNameMap) -> snowpark.Col
     """
     is_qualified_name = name not in column_map.get_snowpark_columns()
     return snowpark_fn.col(name, _is_qualified_name=is_qualified_name)
+
+
+def is_aggregate_function(func_name: str) -> bool:
+    """
+    Check if a function name is an aggregate function.
+
+    Uses a hybrid approach:
+    1. First checks PySpark's docstring convention (docstrings starting with "Aggregate function:")
+    2. Falls back to a hardcoded list for functions with missing/incorrect docstrings
+
+    This ensures comprehensive coverage while automatically supporting new PySpark aggregate functions.
+
+    Args:
+        func_name: The function name to check (case-insensitive)
+
+    Returns:
+        True if the function is an aggregate function, False otherwise
+    """
+    try:
+        import pyspark.sql.functions as pyspark_functions
+
+        # TODO:
+        """
+        Check we can leverage scala classes to determine agg functions:
+        https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L207
+        """
+
+        # Try PySpark docstring approach first (covers most aggregate functions)
+        pyspark_func = getattr(pyspark_functions, func_name.lower(), None)
+        if pyspark_func and pyspark_func.__doc__:
+            if pyspark_func.__doc__.lstrip().startswith("Aggregate function:"):
+                return True
+
+        # Fallback list for aggregate functions with missing/incorrect docstrings
+        # These are known aggregate functions that don't have proper docstring markers
+        fallback_aggregates = {
+            "percentile_cont",
+            "percentile_disc",
+            "any_value",
+            "grouping",
+            "grouping_id",
+        }
+        return func_name.lower() in fallback_aggregates
+
+    except Exception:
+        return False