snowpark-connect 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of snowpark-connect might be problematic.
Files changed (47)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +33 -5
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +425 -269
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  12. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  13. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  14. snowflake/snowpark_connect/relation/map_join.py +8 -0
  15. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  16. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  17. snowflake/snowpark_connect/relation/map_row_ops.py +116 -15
  18. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  19. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  20. snowflake/snowpark_connect/relation/map_stats.py +1 -1
  21. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  22. snowflake/snowpark_connect/relation/read/map_read_csv.py +119 -29
  23. snowflake/snowpark_connect/relation/read/map_read_json.py +57 -36
  24. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  25. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  26. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  27. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  28. snowflake/snowpark_connect/relation/write/map_write.py +67 -4
  29. snowflake/snowpark_connect/server.py +29 -16
  30. snowflake/snowpark_connect/type_mapping.py +75 -3
  31. snowflake/snowpark_connect/utils/context.py +0 -14
  32. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  33. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  34. snowflake/snowpark_connect/utils/session.py +4 -0
  35. snowflake/snowpark_connect/utils/telemetry.py +30 -5
  36. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  37. snowflake/snowpark_connect/version.py +1 -1
  38. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/METADATA +3 -2
  39. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/RECORD +47 -45
  40. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-0.28.1.data → snowpark_connect-0.30.0.data}/scripts/snowpark-submit +0 -0
  43. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/WHEEL +0 -0
  44. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE-binary +0 -0
  45. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/LICENSE.txt +0 -0
  46. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/licenses/NOTICE-binary +0 -0
  47. {snowpark_connect-0.28.1.dist-info → snowpark_connect-0.30.0.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
  from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
 
 
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\x98\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_optionsb\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\xde\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x12\x44\n\x10interval_literal\x18\x03 \x01(\x0b\x32(.snowflake.ext.IntervalLiteralExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_options\"\x9f\x01\n\x19IntervalLiteralExpression\x12\x32\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.Literal\x12\x18\n\x0bstart_field\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tend_field\x18\x03 \x01(\x05H\x01\x88\x01\x01\x42\x0e\n\x0c_start_fieldB\x0c\n\n_end_fieldb\x06proto3')
 
  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,13 +24,15 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snowflake_expression_ext_pb
  if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
  _globals['_EXPEXTENSION']._serialized_start=114
- _globals['_EXPEXTENSION']._serialized_end=266
- _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=268
- _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=348
- _globals['_SUBQUERYEXPRESSION']._serialized_start=351
- _globals['_SUBQUERYEXPRESSION']._serialized_end=979
- _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=623
- _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=810
- _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=813
- _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=957
+ _globals['_EXPEXTENSION']._serialized_end=336
+ _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=338
+ _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=418
+ _globals['_SUBQUERYEXPRESSION']._serialized_start=421
+ _globals['_SUBQUERYEXPRESSION']._serialized_end=1049
+ _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=693
+ _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=880
+ _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=883
+ _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=1027
+ _globals['_INTERVALLITERALEXPRESSION']._serialized_start=1052
+ _globals['_INTERVALLITERALEXPRESSION']._serialized_end=1211
  # @@protoc_insertion_point(module_scope)
@@ -9,12 +9,14 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
  DESCRIPTOR: _descriptor.FileDescriptor
 
  class ExpExtension(_message.Message):
- __slots__ = ("named_argument", "subquery_expression")
+ __slots__ = ("named_argument", "subquery_expression", "interval_literal")
  NAMED_ARGUMENT_FIELD_NUMBER: _ClassVar[int]
  SUBQUERY_EXPRESSION_FIELD_NUMBER: _ClassVar[int]
+ INTERVAL_LITERAL_FIELD_NUMBER: _ClassVar[int]
  named_argument: NamedArgumentExpression
  subquery_expression: SubqueryExpression
- def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ...) -> None: ...
+ interval_literal: IntervalLiteralExpression
+ def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ..., interval_literal: _Optional[_Union[IntervalLiteralExpression, _Mapping]] = ...) -> None: ...
 
  class NamedArgumentExpression(_message.Message):
  __slots__ = ("key", "value")
@@ -56,3 +58,13 @@ class SubqueryExpression(_message.Message):
  table_arg_options: SubqueryExpression.TableArgOptions
  in_subquery_values: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
  def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., subquery_type: _Optional[_Union[SubqueryExpression.SubqueryType, str]] = ..., table_arg_options: _Optional[_Union[SubqueryExpression.TableArgOptions, _Mapping]] = ..., in_subquery_values: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ...) -> None: ...
+
+ class IntervalLiteralExpression(_message.Message):
+ __slots__ = ("literal", "start_field", "end_field")
+ LITERAL_FIELD_NUMBER: _ClassVar[int]
+ START_FIELD_FIELD_NUMBER: _ClassVar[int]
+ END_FIELD_FIELD_NUMBER: _ClassVar[int]
+ literal: _expressions_pb2.Expression.Literal
+ start_field: int
+ end_field: int
+ def __init__(self, literal: _Optional[_Union[_expressions_pb2.Expression.Literal, _Mapping]] = ..., start_field: _Optional[int] = ..., end_field: _Optional[int] = ...) -> None: ...
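For orientation, the new IntervalLiteralExpression wraps a Spark Connect Expression.Literal plus optional start/end interval-field indices. A minimal construction sketch using only the generated classes shown above; the literal value and the field codes are assumptions, not taken from this diff:

    # Sketch only: exact field semantics (units, start/end codes) are assumptions.
    from pyspark.sql.connect.proto import expressions_pb2
    from snowflake.snowpark_connect.proto import snowflake_expression_ext_pb2 as ext_pb2

    interval = ext_pb2.IntervalLiteralExpression(
        literal=expressions_pb2.Expression.Literal(day_time_interval=86_400_000_000),  # assumed: 1 day in microseconds
        start_field=0,  # assumed DAY
        end_field=3,    # assumed SECOND
    )
    ext = ext_pb2.ExpExtension(interval_literal=interval)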
@@ -7,8 +7,27 @@ from urllib.parse import urlparse
  CLOUD_PREFIX_TO_CLOUD = {
  "abfss": "azure",
  "wasbs": "azure",
+ "gcs": "gcp",
+ "gs": "gcp",
  }
 
+ SUPPORTED_COMPRESSION_PER_FORMAT = {
+ "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+ "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+ "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+ "text": {"NONE"},
+ }
+
+
+ def supported_compressions_for_format(format: str) -> set[str]:
+ return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+ def is_supported_compression(format: str, compression: str | None) -> bool:
+ if compression is None:
+ return True
+ return compression in supported_compressions_for_format(format)
+
 
  def get_cloud_from_url(
  url: str,
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
  or path.startswith("azure://")
  or path.startswith("abfss://")
  or path.startswith("wasbs://") # Azure
- or path.startswith("gcs://") # GCP
+ or path.startswith("gcs://")
+ or path.startswith("gs://") # GCP
  )
 
 
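A quick sketch of how the new compression helpers behave, derived directly from the table added above (the read-path call sites live in other files of this release):

    from snowflake.snowpark_connect.relation.io_utils import is_supported_compression

    is_supported_compression("parquet", "SNAPPY")  # True
    is_supported_compression("parquet", "GZIP")    # False: GZIP is not in the parquet set above
    is_supported_compression("csv", None)          # True: a missing compression option is always accepted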
@@ -1124,10 +1124,15 @@ def map_group_map(
  group_by_df = input_df.group_by(*snowpark_grouping_expressions)
  inner_df = group_by_df._dataframe
 
- result = inner_df.select(
- snowpark_fn.call_table_function(
- apply_udtf_temp_name, *inner_df.columns
- ).over(partition_by=snowpark_grouping_expressions)
+ renamed_columns = [f"snowflake_jtf_{column}" for column in input_df.columns]
+ tfc = snowpark_fn.call_table_function(
+ apply_udtf_temp_name, *renamed_columns
+ ).over(partition_by=snowpark_grouping_expressions)
+
+ result = (
+ inner_df.to_df(renamed_columns)
+ .join_table_function(tfc)
+ .drop(*renamed_columns)
  )
  else:
  (
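This appears to be the grouped applyInPandas path: the UDTF inputs are now passed under a snowflake_jtf_ prefix and joined back, so the function's output names cannot collide with the input names. A client-side call that exercises it might look like this (assuming an active Spark Connect session named spark; data and schema are illustrative):

    import pandas as pd

    df = spark.createDataFrame([(1, 10.0), (1, 20.0), (2, 30.0)], ["id", "v"])

    def center(pdf: pd.DataFrame) -> pd.DataFrame:
        # Subtract the per-group mean; the output reuses the input column names on purpose.
        pdf["v"] = pdf["v"] - pdf["v"].mean()
        return pdf

    df.groupBy("id").applyInPandas(center, schema="id long, v double").show()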
@@ -345,7 +345,7 @@ def map_aggregate(
  return new_names[0], snowpark_column
 
  raw_groupings: list[tuple[str, TypedColumn]] = []
- raw_aggregations: list[tuple[str, TypedColumn]] = []
+ raw_aggregations: list[tuple[str, TypedColumn, list[str]]] = []
 
  if not is_group_by_all:
  raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -375,10 +375,21 @@ def map_aggregate(
  # Note: We don't clear the map here to preserve any parent context aliases
  from snowflake.snowpark_connect.utils.context import register_lca_alias
 
+ # If an expression in aggregate.aggregate_expressions is an unresolved attribute, we know it came
+ # straight from the parent map. In that case, look up the parent map's qualifier for it and
+ # propagate it here, in case a later ORDER BY references it in a qualified way.
  agg_count = get_sql_aggregate_function_count()
  for exp in aggregate.aggregate_expressions:
  col = _map_column(exp)
- raw_aggregations.append(col)
+ if exp.WhichOneof("expr_type") == "unresolved_attribute":
+ spark_name = col[0]
+ qualifiers = input_container.column_map.get_qualifier_for_spark_column(
+ spark_name
+ )
+ else:
+ qualifiers = []
+
+ raw_aggregations.append((col[0], col[1], qualifiers))
 
  # If this is an alias, register it in the LCA map for subsequent expressions
  if (
@@ -409,18 +420,20 @@ def map_aggregate(
  spark_columns: list[str] = []
  snowpark_columns: list[str] = []
  snowpark_column_types: list[snowpark_types.DataType] = []
+ all_qualifiers: list[list[str]] = []
 
  # Use grouping columns directly without aliases
  groupings = [col.col for _, col in raw_groupings]
 
  # Create aliases only for aggregation columns
  aggregations = []
- for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+ for i, (spark_name, snowpark_column, qualifiers) in enumerate(raw_aggregations):
  alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]
 
  spark_columns.append(spark_name)
  snowpark_columns.append(alias)
  snowpark_column_types.append(snowpark_column.typ)
+ all_qualifiers.append(qualifiers)
 
  aggregations.append(snowpark_column.col.alias(alias))
 
@@ -483,6 +496,7 @@ def map_aggregate(
  spark_column_names=spark_columns,
  snowpark_column_names=snowpark_columns,
  snowpark_column_types=snowpark_column_types,
+ column_qualifiers=all_qualifiers,
  ).column_map
 
  # Create hybrid column map that can resolve both input and aggregate contexts
@@ -494,7 +508,9 @@ def map_aggregate(
  aggregate_expressions=list(aggregate.aggregate_expressions),
  grouping_expressions=list(aggregate.grouping_expressions),
  spark_columns=spark_columns,
- raw_aggregations=raw_aggregations,
+ raw_aggregations=[
+ (spark_name, col) for spark_name, col, _ in raw_aggregations
+ ],
  )
 
  # Map the HAVING condition using hybrid resolution
@@ -515,4 +531,5 @@ def map_aggregate(
  snowpark_column_names=snowpark_columns,
  snowpark_column_types=snowpark_column_types,
  parent_column_name_map=input_df._column_map,
+ column_qualifiers=all_qualifiers,
  )
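Per the comment added above, the qualifiers are kept so that a later sort can still reference an aggregate output through its dataframe alias. Roughly the shape of query this targets (illustrative; assumes an active Spark Connect session named spark):

    from pyspark.sql import functions as F

    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["k", "v"]).alias("t")
    # "t.k" flows through the aggregate as an unresolved attribute; its qualifier ("t")
    # is now recorded so the qualified orderBy below can still resolve it.
    df.groupBy("t.k").agg(F.sum("v").alias("total")).orderBy("t.k").show()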
@@ -21,6 +21,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
  NATURAL_JOIN_TYPE_BASE,
  map_relation,
  )
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
+ filter_metadata_columns,
+ )
  from snowflake.snowpark_connect.utils.context import (
  push_evaluating_join_condition,
  push_sql_scope,
@@ -38,6 +41,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
  left_container: DataFrameContainer = map_relation(rel.join.left)
  right_container: DataFrameContainer = map_relation(rel.join.right)
 
+ # Remove any metadata columns (like metadata$filename) present in the dataframes.
+ # We cannot support input_file_name for multiple sources, as each dataframe has its own source.
+ left_container = filter_metadata_columns(left_container)
+ right_container = filter_metadata_columns(right_container)
+
  left_input: snowpark.DataFrame = left_container.dataframe
  right_input: snowpark.DataFrame = right_container.dataframe
  is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE
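For context, the filtered column is the per-file metadata (Snowflake's METADATA$FILENAME) that backs input_file_name-style lineage. After a join of two file-backed dataframes that lineage is ambiguous, so it is dropped from both sides first. Illustrative client code (paths and join key are made up; assumes an active session named spark):

    from pyspark.sql import functions as F

    orders = spark.read.parquet("s3://bucket-a/orders/")
    users = spark.read.parquet("s3://bucket-b/users/")

    # The hidden metadata column is removed from both inputs before the join,
    # so F.input_file_name() is not expected to be available on the joined result.
    joined = orders.join(users, "user_id")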
@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
- from snowflake.snowpark_connect.utils.context import map_partitions_depth
  from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
  create_pandas_udtf,
  create_pandas_udtf_with_arrow,
@@ -53,18 +52,18 @@ def _call_udtf(
  ).cast("int"),
  )
 
- udtf_columns = input_df.columns + [snowpark_fn.col("_DUMMY_PARTITION_KEY")]
+ udtf_columns = [f"snowflake_jtf_{column}" for column in input_df.columns] + [
+ "_DUMMY_PARTITION_KEY"
+ ]
 
  tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
  partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
  )
 
- # Use map_partitions_depth only when mapping non nested map_partitions
- # When mapping chained functions additional column casting is necessary
- if map_partitions_depth() == 1:
- result_df_with_dummy = input_df_with_dummy.join_table_function(tfc)
- else:
- result_df_with_dummy = input_df_with_dummy.select(tfc)
+ # Overwrite the input_df columns to prevent name conflicts with UDTF output columns
+ result_df_with_dummy = input_df_with_dummy.to_df(udtf_columns).join_table_function(
+ tfc
+ )
 
  output_cols = [field.name for field in return_type.fields]
 
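The mapInPandas/mapInArrow path no longer branches on nesting depth; every call renames the UDTF inputs (snowflake_jtf_ prefix plus the dummy partition key) and joins the table function. A chained call of the kind that previously took the other branch (illustrative; assumes an active Spark Connect session named spark):

    from typing import Iterator
    import pandas as pd

    def add_one(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        for pdf in batches:
            pdf["v"] = pdf["v"] + 1
            yield pdf

    df = spark.createDataFrame([(1,), (2,)], ["v"])
    df.mapInPandas(add_one, schema="v long").mapInPandas(add_one, schema="v long").show()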
@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.context import (
  get_plan_id_map,
  get_session_id,
  not_resolving_fun_args,
- push_map_partitions,
  push_operation_scope,
  set_is_aggregate_function,
  set_plan_id_map,
@@ -185,8 +184,7 @@ def map_relation(
  )
  return cached_df
  case "map_partitions":
- with push_map_partitions():
- result = map_map_partitions.map_map_partitions(rel)
+ result = map_map_partitions.map_map_partitions(rel)
  case "offset":
  result = map_row_ops.map_offset(rel)
  case "project":
@@ -1,15 +1,14 @@
  #
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #
-
-
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentException
 
  import snowflake.snowpark_connect.relation.utils as utils
  from snowflake import snowpark
- from snowflake.snowpark.functions import col, expr as snowpark_expr
+ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
+ from snowflake.snowpark.functions import col, expr as snowpark_expr, lit
  from snowflake.snowpark.types import (
  BooleanType,
  ByteType,
@@ -20,8 +19,14 @@ from snowflake.snowpark.types import (
  LongType,
  NullType,
  ShortType,
+ StructField,
+ StructType,
+ )
+ from snowflake.snowpark_connect.column_name_handler import (
+ ColumnNameMap,
+ schema_getter,
+ set_schema_getter,
  )
- from snowflake.snowpark_connect.column_name_handler import ColumnNameMap, schema_getter
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
@@ -318,23 +323,37 @@ def map_union(
  right_column_map = right_result.column_map
  columns_to_restore: dict[str, tuple[str, str]] = {}
 
- for column in right_df.columns:
+ original_right_schema = right_df.schema
+ right_renamed_fields = []
+ for field in original_right_schema.fields:
  spark_name = (
- right_column_map.get_spark_column_name_from_snowpark_column_name(column)
+ right_column_map.get_spark_column_name_from_snowpark_column_name(
+ field.name
+ )
  )
- right_df = right_df.withColumnRenamed(column, spark_name)
- columns_to_restore[spark_name.upper()] = (spark_name, column)
+ right_df = right_df.withColumnRenamed(field.name, spark_name)
+ columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+ right_renamed_fields.append(
+ StructField(spark_name, field.datatype, field.nullable)
+ )
+ set_schema_getter(right_df, lambda: StructType(right_renamed_fields))
 
- for column in left_df.columns:
+ original_left_schema = left_df.schema
+ left_renamed_fields = []
+ for field in original_left_schema.fields:
  spark_name = (
- left_column_map.get_spark_column_name_from_snowpark_column_name(column)
+ left_column_map.get_spark_column_name_from_snowpark_column_name(
+ field.name
+ )
  )
- left_df = left_df.withColumnRenamed(column, spark_name)
- columns_to_restore[spark_name.upper()] = (spark_name, column)
+ left_df = left_df.withColumnRenamed(field.name, spark_name)
+ columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+ left_renamed_fields.append(
+ StructField(spark_name, field.datatype, field.nullable)
+ )
+ set_schema_getter(left_df, lambda: StructType(left_renamed_fields))
 
- result = left_df.unionAllByName(
- right_df, allow_missing_columns=allow_missing_columns
- )
+ result = _union_by_name_optimized(left_df, right_df, allow_missing_columns)
 
  if allow_missing_columns:
  spark_columns = []
@@ -809,3 +828,85 @@ def map_tail(
  alias=input_container.alias,
  cached_schema_getter=lambda: input_df.schema,
  )
+
+
+ def _union_by_name_optimized(
+ left_df: snowpark.DataFrame,
+ right_df: snowpark.DataFrame,
+ allow_missing_columns: bool = False,
+ ) -> snowpark.DataFrame:
+ """
+ This implementation is an optimized version of Snowpark's Dataframe::_union_by_name_internal.
+ The only change is that it avoids the redundant schema queries issued by standard Snowpark
+ by reusing already-fetched/calculated schemas.
+ """
+
+ left_schema = left_df.schema
+ right_schema = right_df.schema
+
+ left_cols = {field.name for field in left_schema.fields}
+ right_cols = {field.name for field in right_schema.fields}
+ right_field_map = {field.name: field for field in right_schema.fields}
+
+ missing_left = right_cols - left_cols
+ missing_right = left_cols - right_cols
+
+ def add_nulls(
+ missing_cols: set[str], to_df: snowpark.DataFrame, from_df: snowpark.DataFrame
+ ) -> snowpark.DataFrame:
+ dt_map = {field.name: field.datatype for field in from_df.schema.fields}
+ result = to_df.select(
+ "*",
+ *[lit(None).cast(dt_map[col]).alias(col) for col in missing_cols],
+ )
+
+ result_fields = []
+ for field in to_df.schema.fields:
+ result_fields.append(
+ StructField(field.name, field.datatype, field.nullable)
+ )
+ for col_name in missing_cols:
+ from_field = next(
+ field for field in from_df.schema.fields if field.name == col_name
+ )
+ result_fields.append(
+ StructField(col_name, from_field.datatype, from_field.nullable)
+ )
+
+ set_schema_getter(result, lambda: StructType(result_fields))
+
+ return result
+
+ if missing_left or missing_right:
+ if allow_missing_columns:
+ left = left_df
+ right = right_df
+ if missing_left:
+ left = add_nulls(missing_left, left, right)
+ if missing_right:
+ right = add_nulls(missing_right, right, left)
+ result = left._union_by_name_internal(right, is_all=True)
+
+ result_fields = []
+ for field in left_schema.fields:
+ result_fields.append(
+ StructField(field.name, field.datatype, field.nullable)
+ )
+ for col_name in missing_left:
+ right_field = right_field_map[col_name]
+ result_fields.append(
+ StructField(col_name, right_field.datatype, right_field.nullable)
+ )
+
+ set_schema_getter(result, lambda: StructType(result_fields))
+ return result
+ else:
+ raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
+ missing_left, missing_right
+ )
+
+ result = left_df.unionAllByName(
+ right_df, allow_missing_columns=allow_missing_columns
+ )
+ set_schema_getter(result, lambda: left_df.schema)
+ return result
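The helper above changes only how schemas are obtained; observable behavior still matches unionByName. The client pattern that reaches it (illustrative; assumes an active Spark Connect session named spark):

    df1 = spark.createDataFrame([(1, "x")], ["id", "a"])
    df2 = spark.createDataFrame([(2, "y")], ["id", "b"])

    # Missing columns are filled with typed NULLs on each side, as in add_nulls() above.
    df1.unionByName(df2, allowMissingColumns=True).show()

    # Without allowMissingColumns the mismatch raises, matching the
    # DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG branch above.
    # df1.unionByName(df2).show()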
@@ -15,6 +15,9 @@ from snowflake.snowpark_connect.column_name_handler import set_schema_getter
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
+ filter_metadata_columns,
+ )
 
 
  def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
@@ -26,14 +29,17 @@ def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
  Buffer object as a single cell.
  """
  input_df_container: DataFrameContainer = map_relation(rel.show_string.input)
- raw_input_df = input_df_container.dataframe
- input_df = _handle_datetype_columns(raw_input_df)
+ filtered_container = filter_metadata_columns(input_df_container)
+ display_df = filtered_container.dataframe
+ display_spark_columns = filtered_container.column_map.get_spark_columns()
+
+ input_df = _handle_datetype_columns(display_df)
 
  show_string = input_df._show_string_spark(
  num_rows=rel.show_string.num_rows,
  truncate=rel.show_string.truncate,
  vertical=rel.show_string.vertical,
- _spark_column_names=input_df_container.column_map.get_spark_columns(),
+ _spark_column_names=display_spark_columns,
  _spark_session_tz=global_config.spark_sql_session_timeZone,
  )
  return pandas.DataFrame({"show_string": [show_string]})
@@ -44,14 +50,16 @@ def map_repr_html(rel: relation_proto.Relation) -> pandas.DataFrame:
  Generate the html string representation of the input dataframe.
  """
  input_df_container: DataFrameContainer = map_relation(rel.html_string.input)
- input_df = input_df_container.dataframe
+
+ filtered_container = filter_metadata_columns(input_df_container)
+ input_df = filtered_container.dataframe
 
  input_panda = input_df.toPandas()
  input_panda.rename(
  columns={
  analyzer_utils.unquote_if_quoted(
- input_df_container.column_map.get_snowpark_columns()[i]
- ): input_df_container.column_map.get_spark_columns()[i]
+ filtered_container.column_map.get_snowpark_columns()[i]
+ ): filtered_container.column_map.get_spark_columns()[i]
  for i in range(len(input_panda.columns))
  },
  inplace=True,
@@ -7,6 +7,7 @@ from collections.abc import MutableMapping, MutableSequence
  from contextlib import contextmanager, suppress
  from contextvars import ContextVar
  from functools import reduce
+ from typing import Tuple
 
  import jpype
  import pandas
@@ -31,6 +32,10 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
  from snowflake.snowpark.functions import when_matched, when_not_matched
+ from snowflake.snowpark_connect.client import (
+ SQL_PASS_THROUGH_MARKER,
+ calculate_checksum,
+ )
  from snowflake.snowpark_connect.config import (
  auto_uppercase_non_column_identifiers,
  check_table_supports_operation,
@@ -397,7 +402,7 @@ def map_sql_to_pandas_df(
  returns a tuple of None for SELECT queries to enable lazy evaluation
  """
 
- snowpark_connect_sql_passthrough = get_sql_passthrough()
+ snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)
 
  if not snowpark_connect_sql_passthrough:
  logical_plan = sql_parser().parsePlan(sql_string)
@@ -1047,7 +1052,7 @@ def map_sql_to_pandas_df(
  raise AnalysisException(
  f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
  f"This table was created as a v1 table with a data source that doesn't support column renaming. "
- f"To enable this operation, set 'enable_snowflake_extension_behavior' to 'true'."
+ f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
  )
 
  column_obj = logical_plan.column()
@@ -1282,6 +1287,14 @@ def map_sql_to_pandas_df(
  return pandas.DataFrame({"": [""]}), ""
 
  rows = session.sql(snowflake_sql).collect()
+ case "RefreshTable":
+ table_name_unquoted = ".".join(
+ str(part)
+ for part in as_java_list(logical_plan.child().multipartIdentifier())
+ )
+ SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
+
+ return pandas.DataFrame({"": [""]}), ""
  case _:
  execute_logical_plan(logical_plan)
  return None, None
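With the new RefreshTable branch, a REFRESH TABLE statement issued over Spark SQL is routed to the Snowflake catalog rather than falling through to the generic plan execution. Illustrative usage (table name is made up; assumes an active session named spark):

    # Dispatched to SNOWFLAKE_CATALOG.refreshTable("my_db.my_schema.my_table") per the branch above.
    spark.sql("REFRESH TABLE my_db.my_schema.my_table")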
@@ -1302,6 +1315,27 @@ def get_sql_passthrough() -> bool:
  return get_boolean_session_config_param("snowpark.connect.sql.passthrough")
 
 
+ def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
+ """
+ Checks whether :param sql_stmt: should be executed as SQL pass-through. SQL pass-through is detected in one of two ways:
+ 1) The Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated)
+ 2) :param sql_stmt: was created through SnowflakeSession and carries the correct marker + checksum
+ """
+ if get_sql_passthrough():
+ # legacy style pass-through, sql_stmt should be a whole, valid SQL statement
+ return True, sql_stmt
+
+ # check for new style, SnowflakeSession based SQL pass-through
+ sql_parts = sql_stmt.split(" ", 2)
+ if len(sql_parts) == 3:
+ marker, checksum, sql = sql_parts
+ if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
+ return True, sql
+
+ # Not a SQL pass-through
+ return False, sql_stmt
+
+
  def change_default_to_public(name: str) -> str:
  """
  Change the namespace to PUBLIC when given name is DEFAULT
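Based on the parsing above, a new-style pass-through statement is the marker, a checksum of the SQL, and the SQL itself, separated by single spaces. A sketch of assembling one with the names imported from snowflake.snowpark_connect.client (the marker value and checksum algorithm are not shown in this diff):

    from snowflake.snowpark_connect.client import SQL_PASS_THROUGH_MARKER, calculate_checksum

    sql = "SELECT CURRENT_VERSION()"
    wrapped = f"{SQL_PASS_THROUGH_MARKER} {calculate_checksum(sql)} {sql}"

    # is_valid_passthrough_sql(wrapped) should return (True, sql); any other string
    # returns (False, <unchanged statement>) unless the legacy config flag is set.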
@@ -1397,10 +1431,10 @@ def map_sql(
  In passthrough mode, SAS calls session.sql() instead of invoking the Spark parser.
  This is to mitigate any issue not covered by spark logical plan to protobuf conversion.
  """
- snowpark_connect_sql_passthrough = get_sql_passthrough()
+ snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)
 
  if not snowpark_connect_sql_passthrough:
- logical_plan = sql_parser().parseQuery(rel.sql.query)
+ logical_plan = sql_parser().parseQuery(sql_stmt)
 
  parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
  set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1408,7 +1442,7 @@
  return execute_logical_plan(logical_plan)
  else:
  session = snowpark.Session.get_active_session()
- sql_df = session.sql(rel.sql.query)
+ sql_df = session.sql(sql_stmt)
  columns = sql_df.columns
  return DataFrameContainer.create_with_column_mapping(
  dataframe=sql_df,
@@ -81,7 +81,7 @@ def map_approx_quantile(
  input_df = input_container.dataframe
 
  snowflake_compatible = get_boolean_session_config_param(
- "enable_snowflake_extension_behavior"
+ "snowpark.connect.enable_snowflake_extension_behavior"
  )
 
  if not snowflake_compatible:
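Both renamed lookups now read the fully qualified key, so clients opting into the extension behavior should set it under the new name (value shown is illustrative; assumes an active session named spark):

    spark.conf.set("snowpark.connect.enable_snowflake_extension_behavior", "true")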