snowpark-connect 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of snowpark-connect might be problematic.

Files changed (25)
  1. snowflake/snowpark_connect/expression/function_defaults.py +207 -0
  2. snowflake/snowpark_connect/expression/literal.py +5 -0
  3. snowflake/snowpark_connect/expression/map_expression.py +10 -1
  4. snowflake/snowpark_connect/expression/map_extension.py +12 -2
  5. snowflake/snowpark_connect/expression/map_unresolved_function.py +11 -12
  6. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/__init__.py +16 -0
  7. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/accessors.py +1281 -0
  8. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py +203 -0
  9. snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py +202 -0
  10. snowflake/snowpark_connect/relation/map_column_ops.py +1 -32
  11. snowflake/snowpark_connect/relation/map_extension.py +7 -7
  12. snowflake/snowpark_connect/relation/map_row_ops.py +2 -29
  13. snowflake/snowpark_connect/relation/read/utils.py +6 -7
  14. snowflake/snowpark_connect/relation/utils.py +1 -170
  15. snowflake/snowpark_connect/version.py +1 -1
  16. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/METADATA +1 -1
  17. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/RECORD +25 -20
  18. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-connect +0 -0
  19. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-session +0 -0
  20. {snowpark_connect-0.23.0.data → snowpark_connect-0.24.0.data}/scripts/snowpark-submit +0 -0
  21. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/WHEEL +0 -0
  22. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE-binary +0 -0
  23. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  24. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/licenses/NOTICE-binary +0 -0
  25. {snowpark_connect-0.23.0.dist-info → snowpark_connect-0.24.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/functions.py
@@ -0,0 +1,203 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Additional Spark functions used in pandas-on-Spark.
+ """
+ from typing import Union
+
+ from pyspark import SparkContext
+ import pyspark.sql.functions as F
+ from pyspark.sql.column import Column
+
+ # For supporting Spark Connect
+ from pyspark.sql.utils import is_remote
+
+
+ def product(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_product",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
+ def stddev(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_stddev",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+
+
+ def var(col: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_var",
+             col,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+
+
+ def skew(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_skew",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc))
+
+
+ def kurt(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_kurt",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc))
+
+
+ def mode(col: Column, dropna: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_mode",
+             col,  # type: ignore[arg-type]
+             lit(dropna),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+
+
+ def covar(col1: Column, col2: Column, ddof: int) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "pandas_covar",
+             col1,  # type: ignore[arg-type]
+             col2,  # type: ignore[arg-type]
+             lit(ddof),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
+ def repeat(col: Column, n: Union[int, Column]) -> Column:
+     """
+     Repeats a string column n times, and returns it as a new string column.
+     """
+     _n = F.lit(n) if isinstance(n, int) else n
+     return F.call_udf("repeat", col, _n)
+
+
+ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "ewm",
+             col,  # type: ignore[arg-type]
+             lit(alpha),
+             lit(ignore_na),
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+ def last_non_null(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "last_non_null",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+ def null_index(col: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "null_index",
+             col,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+ def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+     if is_remote():
+         from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+         return _invoke_function_over_columns(  # type: ignore[return-value]
+             "timestampdiff",
+             lit(unit),
+             start,  # type: ignore[arg-type]
+             end,  # type: ignore[arg-type]
+         )
+
+     else:
+         sc = SparkContext._active_spark_context
+         return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
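Each helper above follows the same dual-dispatch pattern: under Spark Connect (`is_remote()`) the call is expressed through `_invoke_function_over_columns`, otherwise it goes through the JVM bridge (`PythonSQLUtils`). A minimal usage sketch, assuming a local Spark 3.5-style installation where this vendored `pyspark.pandas.spark.functions` module is importable:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.pandas.spark.functions as SF  # the module added in this hunk

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (None,)], ["x"])

# Each helper returns a Column; with a Spark Connect session is_remote() is
# true and the call is routed through _invoke_function_over_columns instead
# of the sc._jvm.PythonSQLUtils bridge.
df.agg(
    SF.stddev(F.col("x"), ddof=1).alias("stddev"),
    SF.product(F.col("x"), dropna=True).alias("product"),
).show()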
snowflake/snowpark_connect/includes/python/pyspark/pandas/spark/utils.py
@@ -0,0 +1,202 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """
+ Helpers and utilities to deal with PySpark instances
+ """
+ from typing import overload
+
+ from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
+
+
+ @overload
+ def as_nullable_spark_type(dt: StructType) -> StructType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: ArrayType) -> ArrayType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: MapType) -> MapType:
+     ...
+
+
+ @overload
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     ...
+
+
+ def as_nullable_spark_type(dt: DataType) -> DataType:
+     """
+     Returns a nullable schema or data types.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A", IntegerType(), True),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
+
+     >>> as_nullable_spark_type(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(IntegerType(),
+     ...                 ArrayType(IntegerType(), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", FloatType(), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(IntegerType(),
+             ArrayType(IntegerType(), True), True), True),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', FloatType(), True)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     as_nullable_spark_type(field.dataType),
+                     nullable=True,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
+     elif isinstance(dt, MapType):
+         return MapType(
+             as_nullable_spark_type(dt.keyType),
+             as_nullable_spark_type(dt.valueType),
+             valueContainsNull=True,
+         )
+     else:
+         return dt
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: StructType, *, precision: int = ..., scale: int = ...
+ ) -> StructType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: ArrayType, *, precision: int = ..., scale: int = ...
+ ) -> ArrayType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: MapType, *, precision: int = ..., scale: int = ...
+ ) -> MapType:
+     ...
+
+
+ @overload
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = ..., scale: int = ...
+ ) -> DataType:
+     ...
+
+
+ def force_decimal_precision_scale(
+     dt: DataType, *, precision: int = 38, scale: int = 18
+ ) -> DataType:
+     """
+     Returns a data type with a fixed decimal type.
+
+     The precision and scale of the decimal type are fixed with the given values.
+
+     Examples
+     --------
+     >>> from pyspark.sql.types import *
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A", DecimalType(10, 0), True),
+     ...     StructField("B", DecimalType(14, 7), False)]))  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A', DecimalType(38,18), True),
+         StructField('B', DecimalType(38,18), False)])
+
+     >>> force_decimal_precision_scale(StructType([
+     ...     StructField("A",
+     ...         StructType([
+     ...             StructField('a',
+     ...                 MapType(DecimalType(5, 0),
+     ...                 ArrayType(DecimalType(20, 0), False), False), False),
+     ...             StructField('b', StringType(), True)])),
+     ...     StructField("B", DecimalType(30, 15), False)]),
+     ...     precision=30, scale=15)  # doctest: +NORMALIZE_WHITESPACE
+     StructType([StructField('A',
+         StructType([StructField('a',
+             MapType(DecimalType(30,15),
+             ArrayType(DecimalType(30,15), False), False), False),
+             StructField('b', StringType(), True)]), True),
+         StructField('B', DecimalType(30,15), False)])
+     """
+     if isinstance(dt, StructType):
+         new_fields = []
+         for field in dt.fields:
+             new_fields.append(
+                 StructField(
+                     field.name,
+                     force_decimal_precision_scale(field.dataType, precision=precision, scale=scale),
+                     nullable=field.nullable,
+                     metadata=field.metadata,
+                 )
+             )
+         return StructType(new_fields)
+     elif isinstance(dt, ArrayType):
+         return ArrayType(
+             force_decimal_precision_scale(dt.elementType, precision=precision, scale=scale),
+             containsNull=dt.containsNull,
+         )
+     elif isinstance(dt, MapType):
+         return MapType(
+             force_decimal_precision_scale(dt.keyType, precision=precision, scale=scale),
+             force_decimal_precision_scale(dt.valueType, precision=precision, scale=scale),
+             valueContainsNull=dt.valueContainsNull,
+         )
+     elif isinstance(dt, DecimalType):
+         return DecimalType(precision=precision, scale=scale)
+     else:
+         return dt
+
+
+ def _test() -> None:
+     import doctest
+     import sys
+     import pyspark.pandas.spark.utils
+
+     globs = pyspark.pandas.spark.utils.__dict__.copy()
+     (failure_count, test_count) = doctest.testmod(
+         pyspark.pandas.spark.utils,
+         globs=globs,
+         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+     )
+     if failure_count:
+         sys.exit(-1)
+
+
+ if __name__ == "__main__":
+     _test()
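The two helpers above are schema-normalization utilities: `as_nullable_spark_type` relaxes every field to nullable, and `force_decimal_precision_scale` pins every `DecimalType` to one precision/scale (38, 18 by default). A small illustrative sketch, not taken from snowpark-connect itself:

from pyspark.sql.types import DecimalType, IntegerType, StructField, StructType
from pyspark.pandas.spark.utils import (
    as_nullable_spark_type,
    force_decimal_precision_scale,
)

schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("amount", DecimalType(10, 2), nullable=False),
])

# Every field becomes nullable and every DecimalType is widened to (38, 18),
# which makes schemas comparable regardless of their original decimal sizing.
print(force_decimal_precision_scale(as_nullable_spark_type(schema)))
# StructType([StructField('id', IntegerType(), True),
#             StructField('amount', DecimalType(38,18), True)])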
snowflake/snowpark_connect/relation/map_column_ops.py
@@ -6,12 +6,10 @@ import ast
  import json
  import sys
  from collections import defaultdict
- from copy import copy

  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
  import pyspark.sql.connect.proto.types_pb2 as types_proto
- from pyspark.errors import PySparkValueError
  from pyspark.errors.exceptions.base import AnalysisException
  from pyspark.serializers import CloudPickleSerializer

@@ -46,7 +44,6 @@ from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.relation.utils import (
      TYPE_MAP_FOR_TO_SCHEMA,
-     can_sort_be_flattened,
      snowpark_functions_col,
  )
  from snowflake.snowpark_connect.type_mapping import (
@@ -346,12 +343,6 @@ def map_sort(

      sort_order = sort.order

-     if not sort_order:
-         raise PySparkValueError(
-             error_class="CANNOT_BE_EMPTY",
-             message="At least one column must be specified.",
-         )
-
      if len(sort_order) == 1:
          parsed_col_name = split_fully_qualified_spark_name(
              sort_order[0].child.unresolved_attribute.unparsed_identifier
@@ -433,29 +424,7 @@ def map_sort(
      if not order_specified:
          ascending = None

-     select_statement = getattr(input_df, "_select_statement", None)
-     sort_expressions = [c._expression for c in cols]
-     if (
-         can_sort_be_flattened(select_statement, *sort_expressions)
-         and input_df._ops_after_agg is None
-     ):
-         # "flattened" order by that will allow using dropped columns
-         new = copy(select_statement)
-         new.from_ = select_statement.from_.to_subqueryable()
-         new.pre_actions = new.from_.pre_actions
-         new.post_actions = new.from_.post_actions
-         new.order_by = sort_expressions + (select_statement.order_by or [])
-         new.column_states = select_statement.column_states
-         new._merge_projection_complexity_with_subquery = False
-         new.df_ast_ids = (
-             select_statement.df_ast_ids.copy()
-             if select_statement.df_ast_ids is not None
-             else None
-         )
-         new.attributes = select_statement.attributes
-         result = input_df._with_plan(new)
-     else:
-         result = input_df.sort(cols, ascending=ascending)
+     result = input_df.sort(cols, ascending=ascending)

      return DataFrameContainer(
          result,
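The removed branch rewrote the ORDER BY of the existing Snowpark SelectStatement in place (the deleted comment notes this allowed sorting on dropped columns); the simplified code delegates to Snowpark's `DataFrame.sort`. For orientation, a client-side sketch of the operation this mapper serves, assuming a Spark session backed by snowpark-connect named `spark` and illustrative data only:

import pyspark.sql.functions as F

df = spark.createDataFrame([(2, "b"), (1, "a"), (3, "c")], ["id", "name"])

# Both forms arrive as a Sort relation over the child plan; map_sort now
# translates them into a single Snowpark DataFrame.sort(cols, ascending=...)
# call rather than editing the generated SELECT's ORDER BY clause.
df.orderBy("id").show()
df.sort(F.col("id").desc(), F.col("name")).show()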
snowflake/snowpark_connect/relation/map_extension.py
@@ -347,6 +347,13 @@
      raw_groupings: list[tuple[str, TypedColumn]] = []
      raw_aggregations: list[tuple[str, TypedColumn]] = []

+     if not is_group_by_all:
+         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
+
+     # Set the current grouping columns in context for grouping_id() function
+     grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
+     set_current_grouping_columns(grouping_spark_columns)
+
      agg_count = get_sql_aggregate_function_count()
      for exp in aggregate.aggregate_expressions:
          col = _map_column(exp)
@@ -359,13 +366,6 @@
          else:
              agg_count = new_agg_count

-     if not is_group_by_all:
-         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
-
-     # Set the current grouping columns in context for grouping_id() function
-     grouping_spark_columns = [spark_name for spark_name, _ in raw_groupings]
-     set_current_grouping_columns(grouping_spark_columns)
-
      # Now create column name lists and assign aliases.
      # In case of GROUP BY ALL, even though groupings are a subset of aggregations,
      # they will have their own aliases so we can drop them later.
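The moved block now maps the grouping expressions and calls `set_current_grouping_columns` before the aggregate expressions are mapped, so functions that depend on the grouping context (the comment names `grouping_id()`) resolve against the right columns. A plain PySpark illustration of that dependency (generic Spark API, not snowpark-connect internals):

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["k", "v"])

# grouping_id() is only meaningful relative to the current grouping columns,
# so the mapper has to know them before it maps the aggregate expressions.
df.cube("k").agg(
    F.sum("v").alias("total"),
    F.grouping_id().alias("gid"),
).show()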
snowflake/snowpark_connect/relation/map_row_ops.py
@@ -1,7 +1,7 @@
  #
  # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
  #
- from copy import copy
+

  import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
  import pyspark.sql.connect.proto.relations_pb2 as relation_proto
@@ -9,7 +9,6 @@ from pyspark.errors.exceptions.base import AnalysisException, IllegalArgumentExc

  import snowflake.snowpark_connect.relation.utils as utils
  from snowflake import snowpark
- from snowflake.snowpark._internal.analyzer.binary_expression import And
  from snowflake.snowpark.functions import col, expr as snowpark_expr
  from snowflake.snowpark.types import (
      BooleanType,
@@ -31,7 +30,6 @@ from snowflake.snowpark_connect.expression.map_expression import (
  )
  from snowflake.snowpark_connect.expression.typer import ExpressionTyper
  from snowflake.snowpark_connect.relation.map_relation import map_relation
- from snowflake.snowpark_connect.relation.utils import can_filter_be_flattened
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
@@ -555,32 +553,7 @@
          rel.filter.condition, input_container.column_map, typer
      )

-     select_statement = getattr(input_df, "_select_statement", None)
-     condition_exp = condition.col._expression
-     if (
-         can_filter_be_flattened(select_statement, condition_exp)
-         and input_df._ops_after_agg is None
-     ):
-         new = copy(select_statement)
-         new.from_ = select_statement.from_.to_subqueryable()
-         new.pre_actions = new.from_.pre_actions
-         new.post_actions = new.from_.post_actions
-         new.column_states = select_statement.column_states
-         new.where = (
-             And(select_statement.where, condition_exp)
-             if select_statement.where is not None
-             else condition_exp
-         )
-         new._merge_projection_complexity_with_subquery = False
-         new.df_ast_ids = (
-             select_statement.df_ast_ids.copy()
-             if select_statement.df_ast_ids is not None
-             else None
-         )
-         new.attributes = select_statement.attributes
-         result = input_df._with_plan(new)
-     else:
-         result = input_df.filter(condition.col)
+     result = input_df.filter(condition.col)

      return DataFrameContainer(
          result,
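As with `map_sort`, the removed branch folded the new predicate into the existing SELECT's WHERE clause via `And(...)`; the simplified code chains `DataFrame.filter`. In Snowpark terms the two shapes are logically equivalent, as in this sketch (assumes an existing Snowpark `session`):

from snowflake.snowpark.functions import col

df = session.create_dataframe([[1, 4], [2, 5], [3, 6]], schema=["A", "B"])

merged = df.filter((col("A") > 1) & (col("B") < 6))     # one AND-ed predicate
chained = df.filter(col("A") > 1).filter(col("B") < 6)  # two chained filters

# Same rows either way; sort before comparing since result order is not guaranteed.
assert sorted(map(tuple, merged.collect())) == sorted(map(tuple, chained.collect()))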
snowflake/snowpark_connect/relation/read/utils.py
@@ -73,13 +73,12 @@ def rename_columns_as_snowflake_standard(
          return df, []

      new_columns = make_column_names_snowpark_compatible(df.columns, plan_id)
-     result = df.toDF(*new_columns)
-     if result._select_statement is not None:
-         # do not allow snowpark to flatten the to_df result
-         # TODO: remove after SNOW-2203706 is fixed
-         result._select_statement.flatten_disabled = True
-
-     return (result, new_columns)
+     return (
+         df.select(
+             *(df.col(orig).alias(alias) for orig, alias in zip(df.columns, new_columns))
+         ),
+         new_columns,
+     )


  class Connection(Protocol):
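The new implementation renames columns with an explicit `select(... .alias(...))` projection instead of `toDF(*new_columns)` plus the flatten workaround. A standalone Snowpark sketch of the same renaming pattern; the alias names and `connection_parameters` below are illustrative placeholders, not the names `make_column_names_snowpark_compatible` actually produces:

from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

session = Session.builder.configs(connection_parameters).create()  # placeholder config dict
df = session.create_dataframe([[1, "a"], [2, "b"]], schema=["ID", "NAME"])

new_columns = ["ID_0", "NAME_1"]  # hypothetical snowpark-compatible names
renamed = df.select(
    *(col(orig).alias(alias) for orig, alias in zip(df.columns, new_columns))
)
renamed.show()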