snowpark-connect 0.32.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/column_name_handler.py +91 -40
- snowflake/snowpark_connect/column_qualifier.py +0 -4
- snowflake/snowpark_connect/config.py +9 -0
- snowflake/snowpark_connect/expression/hybrid_column_map.py +5 -4
- snowflake/snowpark_connect/expression/literal.py +12 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +18 -4
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +150 -29
- snowflake/snowpark_connect/expression/map_unresolved_function.py +93 -55
- snowflake/snowpark_connect/relation/map_aggregate.py +156 -257
- snowflake/snowpark_connect/relation/map_column_ops.py +19 -0
- snowflake/snowpark_connect/relation/map_join.py +454 -252
- snowflake/snowpark_connect/relation/map_row_ops.py +136 -54
- snowflake/snowpark_connect/relation/map_sql.py +335 -90
- snowflake/snowpark_connect/relation/read/map_read.py +9 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +19 -2
- snowflake/snowpark_connect/relation/read/map_read_json.py +90 -2
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +3 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +10 -0
- snowflake/snowpark_connect/relation/read/utils.py +41 -0
- snowflake/snowpark_connect/relation/utils.py +50 -2
- snowflake/snowpark_connect/relation/write/map_write.py +251 -292
- snowflake/snowpark_connect/resources_initializer.py +25 -13
- snowflake/snowpark_connect/server.py +9 -24
- snowflake/snowpark_connect/type_mapping.py +2 -0
- snowflake/snowpark_connect/typed_column.py +2 -2
- snowflake/snowpark_connect/utils/context.py +0 -14
- snowflake/snowpark_connect/utils/expression_transformer.py +163 -0
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +4 -1
- snowflake/snowpark_connect/utils/udf_helper.py +1 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +3 -0
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/METADATA +4 -2
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/RECORD +43 -104
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.32.0.data → snowpark_connect-1.0.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.32.0.dist-info → snowpark_connect-1.0.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/map_sql.py
@@ -25,6 +25,7 @@ import snowflake.snowpark.functions as snowpark_fn
 import snowflake.snowpark_connect.proto.snowflake_expression_ext_pb2 as snowflake_exp_proto
 import snowflake.snowpark_connect.proto.snowflake_relation_ext_pb2 as snowflake_proto
 from snowflake import snowpark
+from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,
@@ -61,6 +62,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
     NATURAL_JOIN_TYPE_BASE,
     map_relation,
 )
+
+# Import from utils for consistency
+from snowflake.snowpark_connect.relation.utils import is_aggregate_function
 from snowflake.snowpark_connect.type_mapping import map_snowpark_to_pyspark_types
 from snowflake.snowpark_connect.utils.context import (
     _accessing_temp_object,
@@ -152,6 +156,48 @@ def _push_cte_scope():
         _cte_definitions.reset(def_token)


+def _process_cte_relations(cte_relations):
+    """
+    Process CTE relations and register them in the current CTE scope.
+
+    This function extracts CTE definitions from CTE relations,
+    maps them to protobuf representations, and stores them for later reference.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    for cte in as_java_list(cte_relations):
+        name = str(cte._1())
+        # Store the original CTE definition for re-evaluation
+        _cte_definitions.get()[name] = cte._2()
+        # Process CTE definition with a unique plan_id to ensure proper column naming
+        # Clear HAVING condition before processing each CTE to prevent leakage between CTEs
+        saved_having = _having_condition.get()
+        _having_condition.set(None)
+        try:
+            cte_plan_id = gen_sql_plan_id()
+            cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
+            _ctes.get()[name] = cte_proto
+        finally:
+            _having_condition.set(saved_having)
+
+
+@contextmanager
+def _with_cte_scope(cte_relations):
+    """
+    Context manager that creates a CTE scope and processes CTE relations.
+
+    This combines _push_cte_scope() and _process_cte_relations() to handle
+    the common pattern of processing CTEs within a new scope.
+
+    Args:
+        cte_relations: Java list of CTE relations (tuples of name and SubqueryAlias)
+    """
+    with (_push_cte_scope()):
+        _process_cte_relations(cte_relations)
+        yield
+
+
 @contextmanager
 def _push_window_specs_scope():
     """
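Note: `_with_cte_scope` above is the usual contextmanager pattern of pushing a fresh scope and registering items inside it before yielding. A minimal, package-independent sketch of that pattern, using a hypothetical ContextVar-backed registry in place of the module's `_ctes`/`_cte_definitions` state:

from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical registry standing in for the module's CTE bookkeeping.
_registry: ContextVar[dict] = ContextVar("_registry", default={})

@contextmanager
def _push_scope():
    # Install a fresh dict for this scope; restore the previous one on exit.
    token = _registry.set({})
    try:
        yield
    finally:
        _registry.reset(token)

@contextmanager
def _with_scope(items):
    with _push_scope():
        for name, definition in items:
            _registry.get()[name] = definition  # register each CTE-like entry
        yield

# Entries registered inside the block are gone once the scope is popped.
with _with_scope([("cte1", "SELECT 1")]):
    assert "cte1" in _registry.get()
assert "cte1" not in _registry.get()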
@@ -258,6 +304,130 @@ def _create_table_as_select(logical_plan, mode: str) -> None:
     )


+def _insert_into_table(logical_plan, session: Session) -> None:
+    df_container = execute_logical_plan(logical_plan.query())
+    df = df_container.dataframe
+    queries = df.queries["queries"]
+    if len(queries) != 1:
+        exception = SnowparkConnectNotImplementedError(
+            f"Unexpected number of queries: {len(queries)}"
+        )
+        attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
+        raise exception
+
+    name = get_relation_identifier_name(logical_plan.table(), True)
+
+    user_columns = [
+        spark_to_sf_single_id(str(col), is_column=True)
+        for col in as_java_list(logical_plan.userSpecifiedCols())
+    ]
+    overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
+    cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
+
+    # Extract partition spec if any
+    partition_spec = logical_plan.partitionSpec()
+    partition_map = as_java_map(partition_spec)
+
+    partition_columns = {}
+    for entry in partition_map.entrySet():
+        col_name = str(entry.getKey())
+        value_option = entry.getValue()
+        if value_option.isDefined():
+            partition_columns[col_name] = value_option.get()
+
+    # Add partition columns to the dataframe
+    if partition_columns:
+        """
+        Spark sends them in the partition spec and the values won't be present in the values array.
+        As snowflake does not support static partitions in INSERT INTO statements,
+        we need to add the partition columns to the dataframe as literal columns.
+
+        ex: INSERT INTO TABLE test_table PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+
+        Spark sends: VALUES ('k1', 100), ('k2', 200), ('k3', 300) with partition spec (ds='2021-01-01', hr=10)
+        Snowflake expects: VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+
+        We need to add the partition columns to the dataframe as literal columns.
+
+        ex: df = df.withColumn('ds', snowpark_fn.lit('2021-01-01'))
+            df = df.withColumn('hr', snowpark_fn.lit(10))
+
+        Then the final query will be:
+        INSERT INTO TABLE test_table VALUES ('k1', 100, '2021-01-01', 10), ('k2', 200, '2021-01-01', 10), ('k3', 300, '2021-01-01', 10)
+        """
+        for partition_col, partition_value in partition_columns.items():
+            df = df.withColumn(partition_col, snowpark_fn.lit(partition_value))
+
+    target_table = session.table(name)
+    target_schema = target_table.schema
+
+    expected_number_of_columns = (
+        len(user_columns) if user_columns else len(target_schema.fields)
+    )
+    if expected_number_of_columns != len(df.schema.fields):
+        reason = (
+            "too many data columns"
+            if len(df.schema.fields) > expected_number_of_columns
+            else "not enough data columns"
+        )
+        exception = AnalysisException(
+            f'[INSERT_COLUMN_ARITY_MISMATCH.{reason.replace(" ", "_").upper()}] Cannot write to {name}, the reason is {reason}:\n'
+            f'Table columns: {", ".join(target_schema.names)}.\n'
+            f'Data columns: {", ".join(df.schema.names)}.'
+        )
+        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+        raise exception
+
+    try:
+        # Modify df with type conversions and struct field name mapping
+        modified_columns = []
+        for source_field, target_field in zip(df.schema.fields, target_schema.fields):
+            col_name = source_field.name
+
+            # Handle different type conversions
+            if isinstance(
+                target_field.datatype, snowpark.types.DecimalType
+            ) and isinstance(
+                source_field.datatype,
+                (snowpark.types.FloatType, snowpark.types.DoubleType),
+            ):
+                # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
+                # Only apply this to floating-point source columns
+                modified_col = (
+                    snowpark_fn.when(
+                        snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
+                        snowpark_fn.lit(None),
+                    )
+                    .otherwise(snowpark_fn.col(col_name))
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            elif (
+                isinstance(target_field.datatype, snowpark.types.StructType)
+                and source_field.datatype != target_field.datatype
+            ):
+                # Cast struct with field name mapping (e.g., col1,col2 -> i1,i2)
+                # This fixes INSERT INTO table with struct literals like (2, 3)
+                modified_col = (
+                    snowpark_fn.col(col_name)
+                    .cast(target_field.datatype, rename_fields=True)
+                    .alias(col_name)
+                )
+                modified_columns.append(modified_col)
+            else:
+                modified_columns.append(snowpark_fn.col(col_name))
+
+        df = df.select(modified_columns)
+    except Exception:
+        pass
+
+    queries = df.queries["queries"]
+    final_query = queries[0]
+    session.sql(
+        f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
+    ).collect()
+
+
 def _spark_field_to_sql(field: jpype.JObject, is_column: bool) -> str:
     # Column names will be uppercased according to "snowpark.connect.sql.identifiers.auto-uppercase"
     # if present, or to "spark.sql.caseSensitive".
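Note: the docstring inside `_insert_into_table` explains the key rewrite: Snowflake's INSERT has no PARTITION clause, so static partition values are appended to every row as literal columns before the single generated query is wrapped in `INSERT [OVERWRITE] INTO ...`. A rough sketch of that expansion over plain Python rows (hypothetical data, not the Snowpark DataFrame path used above):

def expand_static_partitions(rows, columns, partition_spec):
    """Append static partition values to every row, mirroring the rewrite above."""
    out_columns = columns + list(partition_spec.keys())
    out_rows = [row + tuple(partition_spec.values()) for row in rows]
    return out_columns, out_rows

# INSERT INTO t PARTITION (ds='2021-01-01', hr=10) VALUES ('k1', 100), ('k2', 200)
columns, rows = expand_static_partitions(
    rows=[("k1", 100), ("k2", 200)],
    columns=["key", "value"],
    partition_spec={"ds": "2021-01-01", "hr": 10},
)
assert columns == ["key", "value", "ds", "hr"]
assert rows == [("k1", 100, "2021-01-01", 10), ("k2", 200, "2021-01-01", 10)]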
@@ -588,25 +758,48 @@ def map_sql_to_pandas_df(
                 f"CREATE TABLE {if_not_exists}{name} LIKE {source}"
             ).collect()
         case "CreateTempViewUsing":
+            parsed_sql = sqlglot.parse_one(sql_string, dialect="spark")
+
+            spark_view_name = next(parsed_sql.find_all(sqlglot.exp.Table)).name
+
+            # extract ONLY top-level column definitions (not nested struct fields)
+            column_defs = []
+            schema_node = next(parsed_sql.find_all(sqlglot.exp.Schema), None)
+            if schema_node:
+                for expr in schema_node.expressions:
+                    if isinstance(expr, sqlglot.exp.ColumnDef):
+                        column_defs.append(expr)
+
+            num_columns = len(column_defs)
+            if num_columns > 0:
+                null_list_parts = []
+                for col_def in column_defs:
+                    col_name = spark_to_sf_single_id(col_def.name, is_column=True)
+                    col_type = col_def.kind
+                    if col_type:
+                        null_list_parts.append(
+                            f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {col_name}"
+                        )
+                    else:
+                        null_list_parts.append(f"NULL AS {col_name}")
+                null_list = ", ".join(null_list_parts)
+            else:
+                null_list = "*"
+
             empty_select = (
-                " AS SELECT
+                f" AS SELECT {null_list} WHERE 1 = 0"
                 if logical_plan.options().isEmpty()
                 and logical_plan.children().isEmpty()
                 else ""
             )
-
-
-                .transform(_normalize_identifiers)
+
+            transformed_sql = (
+                parsed_sql.transform(_normalize_identifiers)
                 .transform(_remove_column_data_type)
                 .transform(_remove_file_format_property)
             )
-            snowflake_sql =
+            snowflake_sql = transformed_sql.sql(dialect="snowflake")
             session.sql(f"{snowflake_sql}{empty_select}").collect()
-            spark_view_name = next(
-                sqlglot.parse_one(sql_string, dialect="spark").find_all(
-                    sqlglot.exp.Table
-                )
-            ).name
             snowflake_view_name = spark_to_sf_single_id_with_unquoting(
                 spark_view_name
             )
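Note: the reworked `CreateTempViewUsing` branch uses sqlglot to pull only the top-level column definitions out of the statement and to render their types in Snowflake syntax for the `CAST(NULL AS ...)` select list. A standalone sketch of that extraction, using a plain CREATE TABLE statement purely for illustration (exact type spellings depend on the sqlglot version):

import sqlglot

parsed = sqlglot.parse_one("CREATE TABLE v (id INT, name STRING)", dialect="spark")

null_list_parts = []
schema_node = next(parsed.find_all(sqlglot.exp.Schema), None)
if schema_node:
    for expr in schema_node.expressions:
        if isinstance(expr, sqlglot.exp.ColumnDef):  # top-level columns only
            col_type = expr.kind
            null_list_parts.append(
                f"CAST(NULL AS {col_type.sql(dialect='snowflake')}) AS {expr.name}"
            )

print(", ".join(null_list_parts))
# something like: CAST(NULL AS INT) AS id, CAST(NULL AS TEXT) AS name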
@@ -856,65 +1049,7 @@ def map_sql_to_pandas_df(
             )
             raise exception
         case "InsertIntoStatement":
-
-            df = df_container.dataframe
-            queries = df.queries["queries"]
-            if len(queries) != 1:
-                exception = SnowparkConnectNotImplementedError(
-                    f"Unexpected number of queries: {len(queries)}"
-                )
-                attach_custom_error_code(
-                    exception, ErrorCodes.UNSUPPORTED_OPERATION
-                )
-                raise exception
-
-            name = get_relation_identifier_name(logical_plan.table(), True)
-
-            user_columns = [
-                spark_to_sf_single_id(str(col), is_column=True)
-                for col in as_java_list(logical_plan.userSpecifiedCols())
-            ]
-            overwrite_str = "OVERWRITE" if logical_plan.overwrite() else ""
-            cols_str = "(" + ", ".join(user_columns) + ")" if user_columns else ""
-
-            try:
-                target_table = session.table(name)
-                target_schema = target_table.schema
-
-                # Modify df with NaN → NULL conversion for DECIMAL columns
-                modified_columns = []
-                for source_field, target_field in zip(
-                    df.schema.fields, target_schema.fields
-                ):
-                    col_name = source_field.name
-                    if isinstance(
-                        target_field.datatype, snowpark.types.DecimalType
-                    ) and isinstance(
-                        source_field.datatype,
-                        (snowpark.types.FloatType, snowpark.types.DoubleType),
-                    ):
-                        # Add CASE WHEN to convert NaN to NULL for DECIMAL targets
-                        # Only apply this to floating-point source columns
-                        modified_col = (
-                            snowpark_fn.when(
-                                snowpark_fn.equal_nan(snowpark_fn.col(col_name)),
-                                snowpark_fn.lit(None),
-                            )
-                            .otherwise(snowpark_fn.col(col_name))
-                            .alias(col_name)
-                        )
-                        modified_columns.append(modified_col)
-                    else:
-                        modified_columns.append(snowpark_fn.col(col_name))
-
-                df = df.select(modified_columns)
-            except Exception:
-                pass
-            queries = df.queries["queries"]
-            final_query = queries[0]
-            session.sql(
-                f"INSERT {overwrite_str} INTO {name} {cols_str} {final_query}",
-            ).collect()
+            _insert_into_table(logical_plan, session)
         case "MergeIntoTable":
             source_df_container = map_relation(
                 map_logical_plan_relation(logical_plan.sourceTable())
@@ -1345,7 +1480,7 @@ def map_sql_to_pandas_df(

             return pandas.DataFrame({"": [""]}), ""
         case "RepairTable":
-            # No-Op
+            # No-Op: Snowflake doesn't have explicit partitions to repair.
             table_relation = logical_plan.child()
             db_and_table_name = as_java_list(table_relation.multipartIdentifier())
             multi_part_len = len(db_and_table_name)
@@ -1371,6 +1506,16 @@ def map_sql_to_pandas_df(
             raise exception

             return pandas.DataFrame({"": [""]}), ""
+        case "UnresolvedWith":
+            child = logical_plan.child()
+            child_class = str(child.getClass().getSimpleName())
+            match child_class:
+                case "InsertIntoStatement":
+                    with _with_cte_scope(logical_plan.cteRelations()):
+                        _insert_into_table(child, get_or_create_snowpark_session())
+                case _:
+                    execute_logical_plan(logical_plan)
+            return None, None
         case _:
             execute_logical_plan(logical_plan)
             return None, None
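Note: the new `UnresolvedWith` branch makes statements such as `WITH src AS (SELECT ...) INSERT INTO tgt SELECT * FROM src` work by registering the CTEs in a scope and then delegating to `_insert_into_table`. Like the rest of the file, it dispatches on the Catalyst node's simple class name; a tiny standalone sketch of that dispatch style with plain Python stand-ins (not the jpype-backed Catalyst objects used above):

# Hypothetical stand-ins for the jpype-wrapped Catalyst nodes.
class InsertIntoStatement:
    pass

class UnresolvedWith:
    def __init__(self, child, cte_relations):
        self.child = child
        self.cte_relations = cte_relations

def handle_unresolved_with(plan) -> str:
    # Real code inspects str(child.getClass().getSimpleName()).
    match type(plan.child).__name__:
        case "InsertIntoStatement":
            # Real code registers plan.cte_relations in a scope, then runs the INSERT.
            return "insert-into-with-ctes"
        case _:
            return "generic-execution"

plan = UnresolvedWith(InsertIntoStatement(), cte_relations=[("src", "SELECT ...")])
assert handle_unresolved_with(plan) == "insert-into-with-ctes"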
@@ -1598,7 +1743,19 @@ def map_logical_plan_relation(
                 attr_parts = as_java_list(expr.nameParts())
                 if len(attr_parts) == 1:
                     attr_name = str(attr_parts[0])
-
+                    if attr_name in alias_map:
+                        # Check if the alias references an aggregate function
+                        # If so, don't substitute because you can't GROUP BY an aggregate
+                        aliased_expr = alias_map[attr_name]
+                        aliased_expr_class = str(
+                            aliased_expr.getClass().getSimpleName()
+                        )
+                        if aliased_expr_class == "UnresolvedFunction":
+                            func_name = str(aliased_expr.nameParts().head())
+                            if is_aggregate_function(func_name):
+                                return expr
+                        return aliased_expr
+                    return expr

                 return expr

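Note: the guard above exists because a GROUP BY alias that points at an aggregate must not be expanded (e.g. `SELECT sum(x) AS total ... GROUP BY total` would otherwise become the invalid `GROUP BY sum(x)`). A string-level sketch of that check, with an illustrative aggregate list standing in for the package's `is_aggregate_function`:

AGGREGATE_FUNCTIONS = {"sum", "count", "avg", "min", "max"}  # illustrative subset

def is_aggregate_name(name):
    return name.lower() in AGGREGATE_FUNCTIONS

def resolve_group_by_alias(attr_name, alias_map):
    """Substitute a SELECT alias in GROUP BY, unless it aliases an aggregate call."""
    aliased = alias_map.get(attr_name)
    if aliased is None:
        return attr_name
    func_name = aliased.split("(", 1)[0]
    if is_aggregate_name(func_name):
        return attr_name  # keep the alias; grouping by an aggregate is invalid
    return aliased

alias_map = {"total": "sum(x)", "year": "extract(year from ts)"}
assert resolve_group_by_alias("total", alias_map) == "total"
assert resolve_group_by_alias("year", alias_map) == "extract(year from ts)"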
@@ -1837,13 +1994,104 @@ def map_logical_plan_relation(
                 )
             )
         case "Sort":
+            # Process the input first
+            input_proto = map_logical_plan_relation(rel.child())
+
+            # Check if child is a Project - if so, build an alias map for ORDER BY resolution
+            # This handles: SELECT o.date AS order_date ... ORDER BY o.date
+            child_class = str(rel.child().getClass().getSimpleName())
+            alias_map = {}
+
+            if child_class == "Project":
+                # Extract aliases from SELECT clause
+                for proj_expr in list(as_java_list(rel.child().projectList())):
+                    if str(proj_expr.getClass().getSimpleName()) == "Alias":
+                        alias_name = str(proj_expr.name())
+                        child_expr = proj_expr.child()
+
+                        # Store mapping from original expression to alias name
+                        # Use string representation for matching
+                        expr_str = str(child_expr)
+                        alias_map[expr_str] = alias_name
+
+                        # Also handle UnresolvedAttribute specifically to get the qualified name
+                        if (
+                            str(child_expr.getClass().getSimpleName())
+                            == "UnresolvedAttribute"
+                        ):
+                            # Get the qualified name like "o.date"
+                            name_parts = list(as_java_list(child_expr.nameParts()))
+                            qualified_name = ".".join(str(part) for part in name_parts)
+                            if qualified_name not in alias_map:
+                                alias_map[qualified_name] = alias_name
+
+            # Process ORDER BY expressions, substituting aliases where needed
+            order_list = []
+            for order_expr in as_java_list(rel.order()):
+                # Get the child expression from the SortOrder
+                child_expr = order_expr.child()
+                expr_class = str(child_expr.getClass().getSimpleName())
+
+                # Check if this expression matches any aliased expression
+                expr_str = str(child_expr)
+                substituted = False
+
+                if expr_str in alias_map:
+                    # Found a match - substitute with alias reference
+                    alias_name = alias_map[expr_str]
+                    # Create new UnresolvedAttribute for the alias
+                    UnresolvedAttribute = jpype.JClass(
+                        "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                    )
+                    new_attr = UnresolvedAttribute.quoted(alias_name)
+
+                    # Create new SortOrder with substituted expression
+                    SortOrder = jpype.JClass(
+                        "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                    )
+                    new_order = SortOrder(
+                        new_attr,
+                        order_expr.direction(),
+                        order_expr.nullOrdering(),
+                        order_expr.sameOrderExpressions(),
+                    )
+                    order_list.append(map_logical_plan_expression(new_order).sort_order)
+                    substituted = True
+                elif expr_class == "UnresolvedAttribute":
+                    # Try matching on qualified name
+                    name_parts = list(as_java_list(child_expr.nameParts()))
+                    qualified_name = ".".join(str(part) for part in name_parts)
+                    if qualified_name in alias_map:
+                        alias_name = alias_map[qualified_name]
+                        UnresolvedAttribute = jpype.JClass(
+                            "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+                        )
+                        new_attr = UnresolvedAttribute.quoted(alias_name)
+
+                        SortOrder = jpype.JClass(
+                            "org.apache.spark.sql.catalyst.expressions.SortOrder"
+                        )
+                        new_order = SortOrder(
+                            new_attr,
+                            order_expr.direction(),
+                            order_expr.nullOrdering(),
+                            order_expr.sameOrderExpressions(),
+                        )
+                        order_list.append(
+                            map_logical_plan_expression(new_order).sort_order
+                        )
+                        substituted = True
+
+                if not substituted:
+                    # No substitution needed - use original
+                    order_list.append(
+                        map_logical_plan_expression(order_expr).sort_order
+                    )
+
             proto = relation_proto.Relation(
                 sort=relation_proto.Sort(
-                    input=
-                    order=
-                        map_logical_plan_expression(e).sort_order
-                        for e in as_java_list(rel.order())
-                    ],
+                    input=input_proto,
+                    order=order_list,
                 )
             )
         case "SubqueryAlias":
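Note: the Sort handling tries to resolve each ORDER BY expression against the SELECT list twice: first on the expression's string form, then on the dot-joined `nameParts()` of an UnresolvedAttribute. A string-level sketch of that two-step fallback with hypothetical inputs (not the Catalyst objects used above):

def substitute_order_by(order_exprs, alias_map):
    """Replace ORDER BY expressions with their SELECT aliases when one exists."""
    resolved = []
    for expr_str, name_parts in order_exprs:
        if expr_str in alias_map:                 # 1) match on the expression's text
            resolved.append(alias_map[expr_str])
        elif ".".join(name_parts) in alias_map:   # 2) match on the dotted name parts
            resolved.append(alias_map[".".join(name_parts)])
        else:
            resolved.append(expr_str)             # no alias: keep the original
    return resolved

# SELECT o.date AS order_date FROM orders o ... ORDER BY o.date, o.id
alias_map = {"o.date": "order_date"}
order_exprs = [("UnresolvedAttribute(o.date)", ["o", "date"]),
               ("UnresolvedAttribute(o.id)", ["o", "id"])]
assert substitute_order_by(order_exprs, alias_map) == [
    "order_date",
    "UnresolvedAttribute(o.id)",
]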
@@ -2030,10 +2278,16 @@ def map_logical_plan_relation(
            )

            # Re-evaluate the CTE definition with a fresh plan_id
-
-
-
-
+            # Clear HAVING condition to prevent leakage from outer CTEs
+            saved_having = _having_condition.get()
+            _having_condition.set(None)
+            try:
+                fresh_plan_id = gen_sql_plan_id()
+                fresh_cte_proto = map_logical_plan_relation(
+                    cte_definition, fresh_plan_id
+                )
+            finally:
+                _having_condition.set(saved_having)

            # Use SubqueryColumnAliases to ensure consistent column names across CTE references
            # This is crucial for CTEs that reference other CTEs
@@ -2188,16 +2442,7 @@ def map_logical_plan_relation(
                 ),
             )
         case "UnresolvedWith":
-            with
-                for cte in as_java_list(rel.cteRelations()):
-                    name = str(cte._1())
-                    # Store the original CTE definition for re-evaluation
-                    _cte_definitions.get()[name] = cte._2()
-                    # Process CTE definition with a unique plan_id to ensure proper column naming
-                    cte_plan_id = gen_sql_plan_id()
-                    cte_proto = map_logical_plan_relation(cte._2(), cte_plan_id)
-                    _ctes.get()[name] = cte_proto
-
+            with _with_cte_scope(rel.cteRelations()):
                 proto = map_logical_plan_relation(rel.child())
         case "LateralJoin":
             left = map_logical_plan_relation(rel.left())
snowflake/snowpark_connect/relation/read/map_read.py
@@ -225,12 +225,20 @@ def _get_supported_read_file_format(unparsed_identifier: str) -> str | None:
     return None


+# TODO: [SNOW-2465948] Remove this once Snowpark fixes the issue with stage paths.
+class StagePathStr(str):
+    def partition(self, __sep):
+        if str(self)[0] == "'":
+            return str(self)[1:].partition(__sep)
+        return str(self).partition(__sep)
+
+
 def _quote_stage_path(stage_path: str) -> str:
     """
     Quote stage paths to escape any special characters.
     """
     if stage_path.startswith("@"):
-        return f"'{stage_path}'"
+        return StagePathStr(f"'{stage_path}'")
     return stage_path

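Note: `StagePathStr` only overrides `str.partition` so that a quoted stage path loses its leading quote before downstream code splits it. A quick demo of the behaviour, copying the class as added above:

class StagePathStr(str):
    def partition(self, __sep):
        if str(self)[0] == "'":
            return str(self)[1:].partition(__sep)
        return str(self).partition(__sep)

quoted = StagePathStr("'@my_stage/data/file.csv'")
plain = "'@my_stage/data/file.csv'"

print(quoted.partition("/"))  # ('@my_stage', '/', "data/file.csv'")
print(plain.partition("/"))   # ("'@my_stage", '/', "data/file.csv'")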
snowflake/snowpark_connect/relation/read/map_read_csv.py
@@ -6,6 +6,7 @@ import copy
 from typing import Any

 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
+from pyspark.errors.exceptions.base import AnalysisException

 import snowflake.snowpark.functions as snowpark_fn
 from snowflake import snowpark
@@ -20,6 +21,7 @@ from snowflake.snowpark_connect.relation.read.metadata_utils import (
     get_non_metadata_fields,
 )
 from snowflake.snowpark_connect.relation.read.utils import (
+    apply_metadata_exclusion_pattern,
     get_spark_column_names_from_snowpark_columns,
     rename_columns_as_snowflake_standard,
 )
@@ -62,6 +64,8 @@ def map_read_csv(
     snowpark_read_options["INFER_SCHEMA"] = snowpark_options.get(
         "INFER_SCHEMA", False
     )
+
+    apply_metadata_exclusion_pattern(snowpark_options)
     snowpark_read_options["PATTERN"] = snowpark_options.get("PATTERN", None)

     raw_options = rel.read.data_source.options
@@ -157,6 +161,7 @@ def get_header_names(
     path: list[str],
     file_format_options: dict,
     snowpark_read_options: dict,
+    raw_options: dict,
 ) -> list[str]:
     no_header_file_format_options = copy.copy(file_format_options)
     no_header_file_format_options["PARSE_HEADER"] = False
@@ -168,7 +173,19 @@ def get_header_names(
     no_header_snowpark_read_options.pop("INFER_SCHEMA", None)

     header_df = session.read.options(no_header_snowpark_read_options).csv(path).limit(1)
-
+    collected_data = header_df.collect()
+
+    if len(collected_data) == 0:
+        error_msg = f"Path does not exist or contains no data: {path}"
+        user_pattern = raw_options.get("pathGlobFilter", None)
+        if user_pattern:
+            error_msg += f" (with pathGlobFilter: {user_pattern})"
+
+        exception = AnalysisException(error_msg)
+        attach_custom_error_code(exception, ErrorCodes.INVALID_INPUT)
+        raise exception
+
+    header_data = collected_data[0]
     return [
         f'"{header_data[i]}"'
         for i in range(len(header_df.schema.fields))
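Note: `get_header_names` now collects the single header row up front and raises a clearer error when the path (optionally narrowed by the user's `pathGlobFilter`) matches nothing. A minimal sketch of just that guard with plain Python stand-ins (the real code raises PySpark's AnalysisException with a custom error code attached):

def check_header_rows(collected_data, path, raw_options):
    """Raise a readable error when the CSV read returned no rows at all."""
    if len(collected_data) == 0:
        error_msg = f"Path does not exist or contains no data: {path}"
        user_pattern = raw_options.get("pathGlobFilter")
        if user_pattern:
            error_msg += f" (with pathGlobFilter: {user_pattern})"
        raise ValueError(error_msg)

try:
    check_header_rows([], ["@stage/empty/"], {"pathGlobFilter": "*.csv"})
except ValueError as e:
    print(e)  # Path does not exist or contains no data: ['@stage/empty/'] (with pathGlobFilter: *.csv)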
@@ -207,7 +224,7 @@ def read_data(
         return df

     headers = get_header_names(
-        session, path, file_format_options, snowpark_read_options
+        session, path, file_format_options, snowpark_read_options, raw_options
     )

     df_schema_fields = non_metadata_fields