snowpark-connect 0.29.0__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of snowpark-connect might be problematic.

Files changed (41)
  1. snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
  2. snowflake/snowpark_connect/client.py +65 -0
  3. snowflake/snowpark_connect/column_name_handler.py +6 -0
  4. snowflake/snowpark_connect/config.py +25 -3
  5. snowflake/snowpark_connect/execute_plan/map_execution_root.py +21 -19
  6. snowflake/snowpark_connect/expression/map_extension.py +277 -1
  7. snowflake/snowpark_connect/expression/map_sql_expression.py +107 -2
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +253 -59
  9. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
  10. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
  11. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  12. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  13. snowflake/snowpark_connect/relation/io_utils.py +61 -4
  14. snowflake/snowpark_connect/relation/map_column_ops.py +9 -4
  15. snowflake/snowpark_connect/relation/map_join.py +8 -0
  16. snowflake/snowpark_connect/relation/map_row_ops.py +129 -17
  17. snowflake/snowpark_connect/relation/map_show_string.py +14 -6
  18. snowflake/snowpark_connect/relation/map_sql.py +39 -5
  19. snowflake/snowpark_connect/relation/map_stats.py +21 -6
  20. snowflake/snowpark_connect/relation/read/map_read.py +9 -0
  21. snowflake/snowpark_connect/relation/read/map_read_csv.py +17 -6
  22. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -2
  23. snowflake/snowpark_connect/relation/read/map_read_parquet.py +7 -1
  24. snowflake/snowpark_connect/relation/read/metadata_utils.py +159 -0
  25. snowflake/snowpark_connect/relation/utils.py +19 -2
  26. snowflake/snowpark_connect/relation/write/map_write.py +44 -29
  27. snowflake/snowpark_connect/server.py +11 -3
  28. snowflake/snowpark_connect/type_mapping.py +75 -3
  29. snowflake/snowpark_connect/utils/describe_query_cache.py +6 -3
  30. snowflake/snowpark_connect/utils/telemetry.py +105 -23
  31. snowflake/snowpark_connect/version.py +1 -1
  32. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/METADATA +1 -1
  33. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/RECORD +41 -37
  34. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-connect +0 -0
  35. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-session +0 -0
  36. {snowpark_connect-0.29.0.data → snowpark_connect-0.30.1.data}/scripts/snowpark-submit +0 -0
  37. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/WHEEL +0 -0
  38. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE-binary +0 -0
  39. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/licenses/NOTICE-binary +0 -0
  41. {snowpark_connect-0.29.0.dist-info → snowpark_connect-0.30.1.dist-info}/top_level.txt +0 -0

snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py

@@ -16,7 +16,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
  from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2


- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\x98\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_optionsb\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1esnowflake_expression_ext.proto\x12\rsnowflake.ext\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\"\xde\x01\n\x0c\x45xpExtension\x12@\n\x0enamed_argument\x18\x01 \x01(\x0b\x32&.snowflake.ext.NamedArgumentExpressionH\x00\x12@\n\x13subquery_expression\x18\x02 \x01(\x0b\x32!.snowflake.ext.SubqueryExpressionH\x00\x12\x44\n\x10interval_literal\x18\x03 \x01(\x0b\x32(.snowflake.ext.IntervalLiteralExpressionH\x00\x42\x04\n\x02op\"P\n\x17NamedArgumentExpression\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.Expression\"\xf4\x04\n\x12SubqueryExpression\x12&\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.Relation\x12\x45\n\rsubquery_type\x18\x02 \x01(\x0e\x32..snowflake.ext.SubqueryExpression.SubqueryType\x12Q\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.snowflake.ext.SubqueryExpression.TableArgOptionsH\x00\x88\x01\x01\x12\x35\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.Expression\x1a\xbb\x01\n\x0fTableArgOptions\x12\x31\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.Expression\x12\x37\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrder\x12\"\n\x15with_single_partition\x18\x03 \x01(\x08H\x00\x88\x01\x01\x42\x18\n\x16_with_single_partition\"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_options\"\x9f\x01\n\x19IntervalLiteralExpression\x12\x32\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.Literal\x12\x18\n\x0bstart_field\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tend_field\x18\x03 \x01(\x05H\x01\x88\x01\x01\x42\x0e\n\x0c_start_fieldB\x0c\n\n_end_fieldb\x06proto3')

  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,13 +24,15 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snowflake_expression_ext_pb
  if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
  _globals['_EXPEXTENSION']._serialized_start=114
- _globals['_EXPEXTENSION']._serialized_end=266
- _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=268
- _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=348
- _globals['_SUBQUERYEXPRESSION']._serialized_start=351
- _globals['_SUBQUERYEXPRESSION']._serialized_end=979
- _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=623
- _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=810
- _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=813
- _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=957
+ _globals['_EXPEXTENSION']._serialized_end=336
+ _globals['_NAMEDARGUMENTEXPRESSION']._serialized_start=338
+ _globals['_NAMEDARGUMENTEXPRESSION']._serialized_end=418
+ _globals['_SUBQUERYEXPRESSION']._serialized_start=421
+ _globals['_SUBQUERYEXPRESSION']._serialized_end=1049
+ _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_start=693
+ _globals['_SUBQUERYEXPRESSION_TABLEARGOPTIONS']._serialized_end=880
+ _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_start=883
+ _globals['_SUBQUERYEXPRESSION_SUBQUERYTYPE']._serialized_end=1027
+ _globals['_INTERVALLITERALEXPRESSION']._serialized_start=1052
+ _globals['_INTERVALLITERALEXPRESSION']._serialized_end=1211
  # @@protoc_insertion_point(module_scope)

snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi

@@ -9,12 +9,14 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
  DESCRIPTOR: _descriptor.FileDescriptor

  class ExpExtension(_message.Message):
- __slots__ = ("named_argument", "subquery_expression")
+ __slots__ = ("named_argument", "subquery_expression", "interval_literal")
  NAMED_ARGUMENT_FIELD_NUMBER: _ClassVar[int]
  SUBQUERY_EXPRESSION_FIELD_NUMBER: _ClassVar[int]
+ INTERVAL_LITERAL_FIELD_NUMBER: _ClassVar[int]
  named_argument: NamedArgumentExpression
  subquery_expression: SubqueryExpression
- def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ...) -> None: ...
+ interval_literal: IntervalLiteralExpression
+ def __init__(self, named_argument: _Optional[_Union[NamedArgumentExpression, _Mapping]] = ..., subquery_expression: _Optional[_Union[SubqueryExpression, _Mapping]] = ..., interval_literal: _Optional[_Union[IntervalLiteralExpression, _Mapping]] = ...) -> None: ...

  class NamedArgumentExpression(_message.Message):
  __slots__ = ("key", "value")
@@ -56,3 +58,13 @@ class SubqueryExpression(_message.Message):
  table_arg_options: SubqueryExpression.TableArgOptions
  in_subquery_values: _containers.RepeatedCompositeFieldContainer[_expressions_pb2.Expression]
  def __init__(self, input: _Optional[_Union[_relations_pb2.Relation, _Mapping]] = ..., subquery_type: _Optional[_Union[SubqueryExpression.SubqueryType, str]] = ..., table_arg_options: _Optional[_Union[SubqueryExpression.TableArgOptions, _Mapping]] = ..., in_subquery_values: _Optional[_Iterable[_Union[_expressions_pb2.Expression, _Mapping]]] = ...) -> None: ...
+
+ class IntervalLiteralExpression(_message.Message):
+ __slots__ = ("literal", "start_field", "end_field")
+ LITERAL_FIELD_NUMBER: _ClassVar[int]
+ START_FIELD_FIELD_NUMBER: _ClassVar[int]
+ END_FIELD_FIELD_NUMBER: _ClassVar[int]
+ literal: _expressions_pb2.Expression.Literal
+ start_field: int
+ end_field: int
+ def __init__(self, literal: _Optional[_Union[_expressions_pb2.Expression.Literal, _Mapping]] = ..., start_field: _Optional[int] = ..., end_field: _Optional[int] = ...) -> None: ...

snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py

@@ -0,0 +1,4 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+

snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py

@@ -0,0 +1,4 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+

snowflake/snowpark_connect/relation/io_utils.py

@@ -4,6 +4,8 @@

  from urllib.parse import urlparse

+ from pyspark.errors.exceptions.base import AnalysisException
+
  CLOUD_PREFIX_TO_CLOUD = {
  "abfss": "azure",
  "wasbs": "azure",
@@ -12,10 +14,28 @@ CLOUD_PREFIX_TO_CLOUD = {
  }

  SUPPORTED_COMPRESSION_PER_FORMAT = {
- "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
- "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
- "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
- "text": {"NONE"},
+ "csv": {
+ "GZIP",
+ "BZ2",
+ "BROTLI",
+ "ZSTD",
+ "DEFLATE",
+ "RAW_DEFLATE",
+ "NONE",
+ "UNCOMPRESSED",
+ },
+ "json": {
+ "GZIP",
+ "BZ2",
+ "BROTLI",
+ "ZSTD",
+ "DEFLATE",
+ "RAW_DEFLATE",
+ "NONE",
+ "UNCOMPRESSED",
+ },
+ "parquet": {"LZO", "SNAPPY", "NONE", "UNCOMPRESSED"},
+ "text": {"NONE", "UNCOMPRESSED"},
  }


@@ -29,6 +49,43 @@ def is_supported_compression(format: str, compression: str | None) -> bool:
  return compression in supported_compressions_for_format(format)


+ def get_compression_for_source_and_options(
+ source: str, options: dict[str, str], from_read: bool = False
+ ) -> str | None:
+ """
+ Determines the compression type to use for a given data source and options.
+ Args:
+ source (str): The data source format (e.g., "csv", "json", "parquet", "text").
+ options (dict[str, str]): A dictionary of options that may include a "compression" key.
+ Returns:
+ str: The compression type to use (e.g., "GZIP", "SNAPPY", "NONE").
+ Raises:
+ AnalysisException: If the specified compression is not supported for the given source format.
+ """
+ # From read, we don't have a default compression
+ if from_read and "compression" not in options:
+ return None
+
+ # Get compression from options for proper filename generation
+ default_compression = "NONE" if source != "parquet" else "snappy"
+ compression = options.get("compression", default_compression).upper()
+ if compression == "UNCOMPRESSED":
+ compression = "NONE"
+
+ if not is_supported_compression(source, compression):
+ supported_compressions = supported_compressions_for_format(source)
+ raise AnalysisException(
+ f"Compression {compression} is not supported for {source} format. "
+ + (
+ f"Supported compressions: {sorted(supported_compressions)}"
+ if supported_compressions
+ else "None compression supported for this format."
+ )
+ )
+
+ return compression
+
+
  def get_cloud_from_url(
  url: str,
  ):
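
The new get_compression_for_source_and_options helper normalizes the Spark-style compression option (uppercasing it, mapping UNCOMPRESSED to NONE, defaulting parquet writes to snappy) and raises AnalysisException for codecs Snowflake cannot handle for that format. A minimal usage sketch follows; the expected results are inferred from the diff above rather than run against the package, and the import path is taken from the map_read.py hunk later in this diff:

    # Usage sketch of the helper added above; results inferred from the diff, not verified against the package.
    from snowflake.snowpark_connect.relation.io_utils import (
        get_compression_for_source_and_options,
    )

    # Write path: parquet falls back to the snappy default when no option is given.
    assert get_compression_for_source_and_options("parquet", {}) == "SNAPPY"

    # Spark's "uncompressed" spelling is normalized to Snowflake's NONE.
    assert get_compression_for_source_and_options("csv", {"compression": "uncompressed"}) == "NONE"

    # Read path: with no explicit option, no compression is forced onto the reader.
    assert get_compression_for_source_and_options("json", {}, from_read=True) is None

    # Unsupported combinations raise AnalysisException, e.g. gzip-compressed text files.
    # get_compression_for_source_and_options("text", {"compression": "gzip"})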

snowflake/snowpark_connect/relation/map_column_ops.py

@@ -1124,10 +1124,15 @@ def map_group_map(
  group_by_df = input_df.group_by(*snowpark_grouping_expressions)
  inner_df = group_by_df._dataframe

- result = inner_df.select(
- snowpark_fn.call_table_function(
- apply_udtf_temp_name, *inner_df.columns
- ).over(partition_by=snowpark_grouping_expressions)
+ renamed_columns = [f"snowflake_jtf_{column}" for column in input_df.columns]
+ tfc = snowpark_fn.call_table_function(
+ apply_udtf_temp_name, *renamed_columns
+ ).over(partition_by=snowpark_grouping_expressions)
+
+ result = (
+ inner_df.to_df(renamed_columns)
+ .join_table_function(tfc)
+ .drop(*renamed_columns)
  )
  else:
  (
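
map_group_map now renames the input columns with a snowflake_jtf_ prefix and attaches the UDTF via join_table_function, so the UDTF's output columns can no longer collide with the input column names. A hedged client-side sketch of the kind of grouped-map call this path serves; whether this exact API routes through map_group_map is an assumption based on the function name, and the connection endpoint is hypothetical:

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # hypothetical endpoint
    df = spark.createDataFrame([(1, 10.0), (1, 20.0), (2, 30.0)], ["id", "v"])

    def center(pdf: pd.DataFrame) -> pd.DataFrame:
        # Output columns reuse the input names, which is exactly the collision
        # the server-side snowflake_jtf_ renaming is meant to avoid.
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupBy("id").applyInPandas(center, schema="id long, v double").show()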

snowflake/snowpark_connect/relation/map_join.py

@@ -21,6 +21,9 @@ from snowflake.snowpark_connect.relation.map_relation import (
  NATURAL_JOIN_TYPE_BASE,
  map_relation,
  )
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
+ filter_metadata_columns,
+ )
  from snowflake.snowpark_connect.utils.context import (
  push_evaluating_join_condition,
  push_sql_scope,
@@ -38,6 +41,11 @@ def map_join(rel: relation_proto.Relation) -> DataFrameContainer:
  left_container: DataFrameContainer = map_relation(rel.join.left)
  right_container: DataFrameContainer = map_relation(rel.join.right)

+ # Remove any metadata columns(like metada$filename) present in the dataframes.
+ # We cannot support inputfilename for multisources as each dataframe has it's own source.
+ left_container = filter_metadata_columns(left_container)
+ right_container = filter_metadata_columns(right_container)
+
  left_input: snowpark.DataFrame = left_container.dataframe
  right_input: snowpark.DataFrame = right_container.dataframe
  is_natural_join = rel.join.join_type >= NATURAL_JOIN_TYPE_BASE

snowflake/snowpark_connect/relation/map_row_ops.py

@@ -1,15 +1,14 @@
  #
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #
-
-
  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentException

  import snowflake.snowpark_connect.relation.utils as utils
  from snowflake import snowpark
- from snowflake.snowpark.functions import col, expr as snowpark_expr
+ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
+ from snowflake.snowpark.functions import col, expr as snowpark_expr, lit
  from snowflake.snowpark.types import (
  BooleanType,
  ByteType,
@@ -20,8 +19,14 @@ from snowflake.snowpark.types import (
  LongType,
  NullType,
  ShortType,
+ StructField,
+ StructType,
+ )
+ from snowflake.snowpark_connect.column_name_handler import (
+ ColumnNameMap,
+ schema_getter,
+ set_schema_getter,
  )
- from snowflake.snowpark_connect.column_name_handler import ColumnNameMap, schema_getter
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.expression.literal import get_literal_field_and_name
@@ -30,6 +35,9 @@ from snowflake.snowpark_connect.expression.map_expression import (
  )
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
+ from snowflake.snowpark_connect.utils.identifiers import (
+ split_fully_qualified_spark_name,
+ )
  from snowflake.snowpark_connect.utils.telemetry import (
  SnowparkConnectNotImplementedError,
  )
@@ -126,11 +134,19 @@ def map_fillna(
  input_df = input_container.dataframe

  if len(rel.fill_na.cols) > 0:
+ if rel.fill_na.cols == ["*"]:
+ # Expand "*" to all columns
+ spark_col_names = input_container.column_map.get_spark_columns()
+ else:
+ spark_col_names = list(rel.fill_na.cols)
+
+ # We don't validate the fully qualified spark name here as fillNa is no-op for structured type colums.
+ # It only works for scalar type columns like float, int, string or bool.
  columns: list[str] = [
  input_container.column_map.get_snowpark_column_name_from_spark_column_name(
- c
+ split_fully_qualified_spark_name(c)[0]
  )
- for c in rel.fill_na.cols
+ for c in spark_col_names
  ]
  values = [get_literal_field_and_name(v)[0] for v in rel.fill_na.values]
  if len(values) == 1:
@@ -318,23 +334,37 @@ def map_union(
  right_column_map = right_result.column_map
  columns_to_restore: dict[str, tuple[str, str]] = {}

- for column in right_df.columns:
+ original_right_schema = right_df.schema
+ right_renamed_fields = []
+ for field in original_right_schema.fields:
  spark_name = (
- right_column_map.get_spark_column_name_from_snowpark_column_name(column)
+ right_column_map.get_spark_column_name_from_snowpark_column_name(
+ field.name
+ )
+ )
+ right_df = right_df.withColumnRenamed(field.name, spark_name)
+ columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+ right_renamed_fields.append(
+ StructField(spark_name, field.datatype, field.nullable)
  )
- right_df = right_df.withColumnRenamed(column, spark_name)
- columns_to_restore[spark_name.upper()] = (spark_name, column)
+ set_schema_getter(right_df, lambda: StructType(right_renamed_fields))

- for column in left_df.columns:
+ original_left_schema = left_df.schema
+ left_renamed_fields = []
+ for field in original_left_schema.fields:
  spark_name = (
- left_column_map.get_spark_column_name_from_snowpark_column_name(column)
+ left_column_map.get_spark_column_name_from_snowpark_column_name(
+ field.name
+ )
+ )
+ left_df = left_df.withColumnRenamed(field.name, spark_name)
+ columns_to_restore[spark_name.upper()] = (spark_name, field.name)
+ left_renamed_fields.append(
+ StructField(spark_name, field.datatype, field.nullable)
  )
- left_df = left_df.withColumnRenamed(column, spark_name)
- columns_to_restore[spark_name.upper()] = (spark_name, column)
+ set_schema_getter(left_df, lambda: StructType(left_renamed_fields))

- result = left_df.unionAllByName(
- right_df, allow_missing_columns=allow_missing_columns
- )
+ result = _union_by_name_optimized(left_df, right_df, allow_missing_columns)

  if allow_missing_columns:
  spark_columns = []
@@ -809,3 +839,85 @@ def map_tail(
  alias=input_container.alias,
  cached_schema_getter=lambda: input_df.schema,
  )
+
+
+ def _union_by_name_optimized(
+ left_df: snowpark.DataFrame,
+ right_df: snowpark.DataFrame,
+ allow_missing_columns: bool = False,
+ ) -> snowpark.DataFrame:
+ """
+ This implementation is an optimized version of Snowpark's Dataframe::_union_by_name_internal.
+ The only change is, that it avoids redundant schema queries that occur in the standard Snowpark,
+ by reusing already-fetched/calculated schemas.
+ """
+
+ left_schema = left_df.schema
+ right_schema = right_df.schema
+
+ left_cols = {field.name for field in left_schema.fields}
+ right_cols = {field.name for field in right_schema.fields}
+ right_field_map = {field.name: field for field in right_schema.fields}
+
+ missing_left = right_cols - left_cols
+ missing_right = left_cols - right_cols
+
+ def add_nulls(
+ missing_cols: set[str], to_df: snowpark.DataFrame, from_df: snowpark.DataFrame
+ ) -> snowpark.DataFrame:
+ dt_map = {field.name: field.datatype for field in from_df.schema.fields}
+ result = to_df.select(
+ "*",
+ *[lit(None).cast(dt_map[col]).alias(col) for col in missing_cols],
+ )
+
+ result_fields = []
+ for field in to_df.schema.fields:
+ result_fields.append(
+ StructField(field.name, field.datatype, field.nullable)
+ )
+ for col_name in missing_cols:
+ from_field = next(
+ field for field in from_df.schema.fields if field.name == col_name
+ )
+ result_fields.append(
+ StructField(col_name, from_field.datatype, from_field.nullable)
+ )
+
+ set_schema_getter(result, lambda: StructType(result_fields))
+
+ return result
+
+ if missing_left or missing_right:
+ if allow_missing_columns:
+ left = left_df
+ right = right_df
+ if missing_left:
+ left = add_nulls(missing_left, left, right)
+ if missing_right:
+ right = add_nulls(missing_right, right, left)
+ result = left._union_by_name_internal(right, is_all=True)
+
+ result_fields = []
+ for field in left_schema.fields:
+ result_fields.append(
+ StructField(field.name, field.datatype, field.nullable)
+ )
+ for col_name in missing_left:
+ right_field = right_field_map[col_name]
+ result_fields.append(
+ StructField(col_name, right_field.datatype, right_field.nullable)
+ )
+
+ set_schema_getter(result, lambda: StructType(result_fields))
+ return result
+ else:
+ raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG(
+ missing_left, missing_right
+ )
+
+ result = left_df.unionAllByName(
+ right_df, allow_missing_columns=allow_missing_columns
+ )
+ set_schema_getter(result, lambda: left_df.schema)
+ return result
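
The new _union_by_name_optimized mirrors Snowpark's _union_by_name_internal but pads missing columns with typed NULL literals itself and pre-registers the result schema through set_schema_getter, so the union no longer triggers extra describe queries. A hedged client-side sketch of the call that exercises this path, using the standard PySpark API; the connection endpoint is hypothetical:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # hypothetical endpoint

    a = spark.createDataFrame([(1, "x")], ["id", "name"])
    b = spark.createDataFrame([(2, 3.5)], ["id", "score"])

    # Columns missing on either side are filled with NULLs of the other side's type;
    # the padded schema is now computed from the already-known schemas instead of re-described.
    a.unionByName(b, allowMissingColumns=True).show()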

snowflake/snowpark_connect/relation/map_show_string.py

@@ -15,6 +15,9 @@ from snowflake.snowpark_connect.column_name_handler import set_schema_getter
  from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation
+ from snowflake.snowpark_connect.relation.read.metadata_utils import (
+ filter_metadata_columns,
+ )


  def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
@@ -26,14 +29,17 @@ def map_show_string(rel: relation_proto.Relation) -> pandas.DataFrame:
  Buffer object as a single cell.
  """
  input_df_container: DataFrameContainer = map_relation(rel.show_string.input)
- raw_input_df = input_df_container.dataframe
- input_df = _handle_datetype_columns(raw_input_df)
+ filtered_container = filter_metadata_columns(input_df_container)
+ display_df = filtered_container.dataframe
+ display_spark_columns = filtered_container.column_map.get_spark_columns()
+
+ input_df = _handle_datetype_columns(display_df)

  show_string = input_df._show_string_spark(
  num_rows=rel.show_string.num_rows,
  truncate=rel.show_string.truncate,
  vertical=rel.show_string.vertical,
- _spark_column_names=input_df_container.column_map.get_spark_columns(),
+ _spark_column_names=display_spark_columns,
  _spark_session_tz=global_config.spark_sql_session_timeZone,
  )
  return pandas.DataFrame({"show_string": [show_string]})
@@ -44,14 +50,16 @@ def map_repr_html(rel: relation_proto.Relation) -> pandas.DataFrame:
  Generate the html string representation of the input dataframe.
  """
  input_df_container: DataFrameContainer = map_relation(rel.html_string.input)
- input_df = input_df_container.dataframe
+
+ filtered_container = filter_metadata_columns(input_df_container)
+ input_df = filtered_container.dataframe

  input_panda = input_df.toPandas()
  input_panda.rename(
  columns={
  analyzer_utils.unquote_if_quoted(
- input_df_container.column_map.get_snowpark_columns()[i]
- ): input_df_container.column_map.get_spark_columns()[i]
+ filtered_container.column_map.get_snowpark_columns()[i]
+ ): filtered_container.column_map.get_spark_columns()[i]
  for i in range(len(input_panda.columns))
  },
  inplace=True,

snowflake/snowpark_connect/relation/map_sql.py

@@ -7,6 +7,7 @@ from collections.abc import MutableMapping, MutableSequence
  from contextlib import contextmanager, suppress
  from contextvars import ContextVar
  from functools import reduce
+ from typing import Tuple

  import jpype
  import pandas
@@ -31,6 +32,10 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import (
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.snowpark._internal.utils import is_sql_select_statement, quote_name
  from snowflake.snowpark.functions import when_matched, when_not_matched
+ from snowflake.snowpark_connect.client import (
+ SQL_PASS_THROUGH_MARKER,
+ calculate_checksum,
+ )
  from snowflake.snowpark_connect.config import (
  auto_uppercase_non_column_identifiers,
  check_table_supports_operation,
@@ -397,7 +402,7 @@ def map_sql_to_pandas_df(
  returns a tuple of None for SELECT queries to enable lazy evaluation
  """

- snowpark_connect_sql_passthrough = get_sql_passthrough()
+ snowpark_connect_sql_passthrough, sql_string = is_valid_passthrough_sql(sql_string)

  if not snowpark_connect_sql_passthrough:
  logical_plan = sql_parser().parsePlan(sql_string)
@@ -1047,7 +1052,7 @@ def map_sql_to_pandas_df(
  raise AnalysisException(
  f"ALTER TABLE RENAME COLUMN is not supported for table '{full_table_identifier}'. "
  f"This table was created as a v1 table with a data source that doesn't support column renaming. "
- f"To enable this operation, set 'enable_snowflake_extension_behavior' to 'true'."
+ f"To enable this operation, set 'snowpark.connect.enable_snowflake_extension_behavior' to 'true'."
  )

  column_obj = logical_plan.column()
@@ -1282,6 +1287,14 @@ def map_sql_to_pandas_df(
  return pandas.DataFrame({"": [""]}), ""

  rows = session.sql(snowflake_sql).collect()
+ case "RefreshTable":
+ table_name_unquoted = ".".join(
+ str(part)
+ for part in as_java_list(logical_plan.child().multipartIdentifier())
+ )
+ SNOWFLAKE_CATALOG.refreshTable(table_name_unquoted)
+
+ return pandas.DataFrame({"": [""]}), ""
  case _:
  execute_logical_plan(logical_plan)
  return None, None
@@ -1302,6 +1315,27 @@ def get_sql_passthrough() -> bool:
  return get_boolean_session_config_param("snowpark.connect.sql.passthrough")


+ def is_valid_passthrough_sql(sql_stmt: str) -> Tuple[bool, str]:
+ """
+ Checks if :param sql_stmt: should be executed as SQL pass-through. SQL pass-through can be detected in 1 of 2 ways:
+ 1) Either Spark config parameter "snowpark.connect.sql.passthrough" is set (legacy mode, to be deprecated)
+ 2) If :param sql_stmt: is created through SnowflakeSession and has correct marker + checksum
+ """
+ if get_sql_passthrough():
+ # legacy style pass-through, sql_stmt should be a whole, valid SQL statement
+ return True, sql_stmt
+
+ # check for new style, SnowflakeSession based SQL pass-through
+ sql_parts = sql_stmt.split(" ", 2)
+ if len(sql_parts) == 3:
+ marker, checksum, sql = sql_parts
+ if marker == SQL_PASS_THROUGH_MARKER and checksum == calculate_checksum(sql):
+ return True, sql
+
+ # Not a SQL pass-through
+ return False, sql_stmt
+
+
  def change_default_to_public(name: str) -> str:
  """
  Change the namespace to PUBLIC when given name is DEFAULT
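
With is_valid_passthrough_sql, pass-through no longer has to be switched on globally: a statement framed as "MARKER CHECKSUM SQL" is executed verbatim when the checksum matches, while everything else still goes through the Spark parser. A hedged sketch of the producing side, assuming only the names imported in the hunk above; the actual marker value and checksum algorithm live in snowflake/snowpark_connect/client.py and are not shown in this diff:

    from snowflake.snowpark_connect.client import SQL_PASS_THROUGH_MARKER, calculate_checksum

    def frame_passthrough(sql: str) -> str:
        # "MARKER CHECKSUM SQL" is the shape is_valid_passthrough_sql() splits apart,
        # assuming the marker itself contains no spaces.
        return f"{SQL_PASS_THROUGH_MARKER} {calculate_checksum(sql)} {sql}"

    stmt = frame_passthrough("SELECT CURRENT_ACCOUNT()")
    # Server side: is_valid_passthrough_sql(stmt) -> (True, "SELECT CURRENT_ACCOUNT()")
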
@@ -1397,10 +1431,10 @@ def map_sql(
  In passthough mode as True, SAS calls session.sql() and not calling Spark Parser.
  This is to mitigate any issue not covered by spark logical plan to protobuf conversion.
  """
- snowpark_connect_sql_passthrough = get_sql_passthrough()
+ snowpark_connect_sql_passthrough, sql_stmt = is_valid_passthrough_sql(rel.sql.query)

  if not snowpark_connect_sql_passthrough:
- logical_plan = sql_parser().parseQuery(rel.sql.query)
+ logical_plan = sql_parser().parseQuery(sql_stmt)

  parsed_pos_args = parse_pos_args(logical_plan, rel.sql.pos_args)
  set_sql_args(rel.sql.args, parsed_pos_args)
@@ -1408,7 +1442,7 @@
  return execute_logical_plan(logical_plan)
  else:
  session = snowpark.Session.get_active_session()
- sql_df = session.sql(rel.sql.query)
+ sql_df = session.sql(sql_stmt)
  columns = sql_df.columns
  return DataFrameContainer.create_with_column_mapping(
  dataframe=sql_df,

snowflake/snowpark_connect/relation/map_stats.py

@@ -81,7 +81,7 @@ def map_approx_quantile(
  input_df = input_container.dataframe

  snowflake_compatible = get_boolean_session_config_param(
- "enable_snowflake_extension_behavior"
+ "snowpark.connect.enable_snowflake_extension_behavior"
  )

  if not snowflake_compatible:
@@ -309,9 +309,28 @@ def map_freq_items(rel: relation_proto.Relation) -> DataFrameContainer:
  cols = input_container.column_map.get_snowpark_column_names_from_spark_column_names(
  list(rel.freq_items.cols)
  )
+
+ # handle empty DataFrame case
+ row_count = input_df.count()
+
+ for sp_col_name in cols:
+ spark_col_names.append(
+ f"{input_container.column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
+ )
+
+ if row_count == 0:
+ # If DataFrame is empty, return empty arrays for each column
+ empty_values = [[] for _ in cols]
+ approx_top_k_df = session.createDataFrame([empty_values], spark_col_names)
+ return DataFrameContainer.create_with_column_mapping(
+ dataframe=approx_top_k_df,
+ spark_column_names=spark_col_names,
+ snowpark_column_names=spark_col_names,
+ )
+
  approx_top_k_df = input_df.select(
  *[
- fn.function("approx_top_k")(fn.col(col), round(input_df.count() / support))
+ fn.function("approx_top_k")(fn.col(col), round(row_count / support))
  for col in cols
  ]
  )
@@ -330,10 +349,6 @@ def map_freq_items(rel: relation_proto.Relation) -> DataFrameContainer:
  for value in approx_top_k_values
  ]

- for sp_col_name in cols:
- spark_col_names.append(
- f"{input_container.column_map.get_spark_column_name_from_snowpark_column_name(sp_col_name)}_freqItems"
- )
  approx_top_k_df = session.createDataFrame([filtered_values], spark_col_names)

  return DataFrameContainer.create_with_column_mapping(
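
map_freq_items now fetches the row count once, builds the output names (each input column suffixed with _freqItems) up front, and short-circuits an empty input by returning a single row of empty arrays instead of running approx_top_k over zero rows. A hedged client-side sketch using the standard PySpark API; the connection endpoint is hypothetical:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # hypothetical endpoint

    df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "a")], ["id", "tag"])

    # Result columns are named id_freqItems and tag_freqItems.
    df.freqItems(["id", "tag"], support=0.5).show()

    # An empty input now yields one row of empty arrays, one per requested column.
    spark.createDataFrame([], "id int, tag string").freqItems(["id", "tag"]).show()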

snowflake/snowpark_connect/relation/read/map_read.py

@@ -17,6 +17,7 @@ from snowflake.snowpark_connect.config import global_config
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.io_utils import (
  convert_file_prefix_path,
+ get_compression_for_source_and_options,
  is_cloud_path,
  )
  from snowflake.snowpark_connect.relation.read.map_read_table import map_read_table
@@ -237,6 +238,14 @@ def _read_file(
  )
  upload_files_if_needed(paths, clean_source_paths, session, read_format)
  paths = [_quote_stage_path(path) for path in paths]
+
+ if read_format in ("csv", "text", "json", "parquet"):
+ compression = get_compression_for_source_and_options(
+ read_format, options, from_read=True
+ )
+ if compression is not None:
+ options["compression"] = compression
+
  match read_format:
  case "csv":
  from snowflake.snowpark_connect.relation.read.map_read_csv import (