snowpark-connect 0.27.0__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/column_name_handler.py +3 -93
- snowflake/snowpark_connect/config.py +99 -1
- snowflake/snowpark_connect/dataframe_container.py +0 -6
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_expression.py +22 -7
- snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +4 -26
- snowflake/snowpark_connect/expression/map_unresolved_function.py +12 -3
- snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/map_extension.py +14 -10
- snowflake/snowpark_connect/relation/map_join.py +62 -258
- snowflake/snowpark_connect/relation/map_relation.py +5 -1
- snowflake/snowpark_connect/relation/map_sql.py +464 -68
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/write/map_write.py +228 -120
- snowflake/snowpark_connect/resources_initializer.py +20 -5
- snowflake/snowpark_connect/server.py +16 -17
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/context.py +21 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
- snowflake/snowpark_connect/utils/identifiers.py +128 -2
- snowflake/snowpark_connect/utils/io_utils.py +21 -1
- snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
- snowflake/snowpark_connect/utils/session.py +16 -26
- snowflake/snowpark_connect/utils/telemetry.py +53 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/RECORD +41 -41
- snowflake/snowpark_connect/hidden_column.py +0 -39
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/map_unresolved_attribute.py

@@ -163,7 +163,6 @@ def map_unresolved_attribute(
     attr_name = ".".join(name_parts)

     has_plan_id = exp.unresolved_attribute.HasField("plan_id")
-    source_qualifiers = None

     if has_plan_id:
         plan_id = exp.unresolved_attribute.plan_id

@@ -172,27 +171,13 @@ def map_unresolved_attribute(
         assert (
             target_df is not None
         ), f"resolving an attribute of a unresolved dataframe {plan_id}"
-
-        # Get the qualifiers for this column from the target DataFrame
-        source_qualifiers = (
-            target_df_container.column_map.get_qualifier_for_spark_column(
-                name_parts[-1]
-            )
-        )
-
-        if hasattr(column_mapping, "hidden_columns"):
-            hidden = column_mapping.hidden_columns
-        else:
-            hidden = None
-
         column_mapping = target_df_container.column_map
-        column_mapping.hidden_columns = hidden
         typer = ExpressionTyper(target_df)

-    def get_col(snowpark_name
+    def get_col(snowpark_name):
         return (
             snowpark_fn.col(snowpark_name)
-            if not has_plan_id
+            if not has_plan_id
             else target_df.col(snowpark_name)
         )

@@ -277,17 +262,10 @@ def map_unresolved_attribute(
     quoted_attr_name = name_parts[0]

     snowpark_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
-        quoted_attr_name,
-        allow_non_exists=True,
-        is_qualified=has_plan_id,
-        source_qualifiers=source_qualifiers if has_plan_id else None,
+        quoted_attr_name, allow_non_exists=True
     )
-
     if snowpark_name is not None:
-
-            quoted_attr_name, source_qualifiers
-        )
-        col = get_col(snowpark_name, is_hidden)
+        col = get_col(snowpark_name)
         qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_attr_name)
     else:
         # this means it has to be a struct column with a field name
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -2619,9 +2619,18 @@ def map_unresolved_function(
                     result_type = input_type.element_type
                     result_exp = fn(snowpark_args[0])
                 case _:
-
-
-
+                    # Check if the type has map-like attributes before accessing them
+                    if hasattr(input_type, "key_type") and hasattr(
+                        input_type, "value_type"
+                    ):
+                        spark_col_names = ["key", "value"]
+                        result_exp = fn(snowpark_args[0])
+                        result_type = [input_type.key_type, input_type.value_type]
+                    else:
+                        # Throw proper error for types without key_type/value_type attributes
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the ("ARRAY" or "MAP") type, however "{snowpark_arg_names[0]}" has the type "{str(input_type)}".'
+                        )
         case "expm1":
             spark_function_name = f"EXPM1({snowpark_arg_names[0]})"
             result_exp = snowpark_fn.exp(*snowpark_args) - 1
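The new `case _:` branch above guards the map access with `hasattr` and otherwise raises the same DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE error class that PySpark reports when a function requiring an ARRAY or MAP argument receives another type. A minimal client-side sketch of the error this surfaces, assuming a Spark Connect session served by snowpark-connect at a hypothetical endpoint and using explode() purely as an illustrative function (the hunk itself does not name the function it belongs to):

    # Hypothetical illustration: endpoint URL and the choice of explode() are
    # assumptions, not taken from this diff.
    from pyspark.errors import AnalysisException
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import explode

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.createDataFrame([(1,)], ["id"])  # "id" is a LONG, not an ARRAY or MAP

    try:
        df.select(explode(df.id)).show()
    except AnalysisException as e:
        # Expected to contain [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] ...
        print(e)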
snowflake/snowpark_connect/expression/map_unresolved_star.py

@@ -34,7 +34,6 @@ def map_unresolved_star(
     column_mapping: ColumnNameMap,
     typer: ExpressionTyper,
 ) -> tuple[list[str], TypedColumn]:
-
     if exp.unresolved_star.HasField("unparsed_target"):
         unparsed_target = exp.unresolved_star.unparsed_target
         name_parts = split_fully_qualified_spark_name(unparsed_target)

@@ -103,7 +102,7 @@ def map_unresolved_star(
             prefix_candidate_str = f"{prefix_candidate_str}.{name_parts[i]}"
             prefix_candidate = (
                 column_mapping.get_snowpark_column_name_from_spark_column_name(
-                    prefix_candidate_str, allow_non_exists=True
+                    prefix_candidate_str, allow_non_exists=True
                 )
             )
             if prefix_candidate is None:

@@ -181,7 +180,7 @@ def map_unresolved_star_struct(
             prefix_candidate_str = f"{prefix_candidate_str}.{name_parts[i]}"
             prefix_candidate = (
                 column_mapping.get_snowpark_column_name_from_spark_column_name(
-                    prefix_candidate_str, allow_non_exists=True
+                    prefix_candidate_str, allow_non_exists=True
                 )
             )
             if prefix_candidate is None:
snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar

Binary file
snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py

@@ -13,7 +13,6 @@ from snowflake.core.exceptions import APIError, NotFoundError
 from snowflake.core.schema import Schema
 from snowflake.core.table import Table, TableColumn

-from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,

@@ -34,12 +33,19 @@ from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import
 )
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils.identifiers import (
+    FQN,
+    spark_to_sf_single_id_with_unquoting,
     split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
+from snowflake.snowpark_connect.utils.temporary_view_cache import (
+    get_temp_view,
+    get_temp_view_normalized_names,
+    unregister_temp_view,
+)
 from snowflake.snowpark_connect.utils.udf_cache import cached_udf


@@ -203,6 +209,93 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})

+    def _get_temp_view_prefixes(self, spark_dbName: str | None) -> list[str]:
+        if spark_dbName is None:
+            return []
+        return [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_dbName)
+        ]
+
+    def _list_temp_views(
+        self,
+        spark_dbName: str | None = None,
+        pattern: str | None = None,
+    ) -> typing.Tuple[
+        list[str | None],
+        list[list[str | None]],
+        list[str],
+        list[str | None],
+        list[str | None],
+        list[bool],
+    ]:
+        catalogs: list[str | None] = list()
+        namespaces: list[list[str | None]] = list()
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        table_types: list[str | None] = list()
+        is_temporaries: list[bool] = list()
+
+        temp_views_prefix = ".".join(self._get_temp_view_prefixes(spark_dbName))
+        normalized_spark_dbName = (
+            temp_views_prefix.lower()
+            if global_config.spark_sql_caseSensitive
+            else temp_views_prefix
+        )
+        normalized_global_temp_database_name = (
+            quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase.lower()
+            )
+            if global_config.spark_sql_caseSensitive
+            else quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase
+            )
+        )
+
+        temp_views = get_temp_view_normalized_names()
+        null_safe_pattern = pattern if pattern is not None else ""
+
+        for temp_view in temp_views:
+            normalized_temp_view = (
+                temp_view.lower()
+                if global_config.spark_sql_caseSensitive
+                else temp_view
+            )
+            fqn = FQN.from_string(temp_view)
+            normalized_schema = (
+                fqn.schema.lower()
+                if fqn.schema is not None and global_config.spark_sql_caseSensitive
+                else fqn.schema
+            )
+
+            is_global_view = normalized_global_temp_database_name == normalized_schema
+            is_local_temp_view = fqn.schema is None
+            # Temporary views are always shown if they match the pattern
+            matches_prefix = (
+                normalized_spark_dbName == normalized_schema or is_local_temp_view
+            )
+            if matches_prefix and bool(
+                re.match(null_safe_pattern, normalized_temp_view)
+            ):
+                names.append(unquote_if_quoted(fqn.name))
+                catalogs.append(None)
+                namespaces.append(
+                    [global_config.spark_sql_globalTempDatabase]
+                    if is_global_view
+                    else []
+                )
+                descriptions.append(None)
+                table_types.append("TEMPORARY")
+                is_temporaries.append(True)
+        return (
+            catalogs,
+            namespaces,
+            names,
+            descriptions,
+            table_types,
+            is_temporaries,
+        )
+
     def listTables(
         self,
         spark_dbName: str | None = None,
@@ -232,8 +325,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             schema=sf_quote(sf_schema),
             pattern=_normalize_identifier(pattern),
         )
-
-        catalogs: list[str] = list()
+        catalogs: list[str | None] = list()
         namespaces: list[list[str | None]] = list()
         names: list[str] = list()
         descriptions: list[str | None] = list()

@@ -253,6 +345,22 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             descriptions.append(o[6] if o[6] else None)
             table_types.append("PERMANENT")
             is_temporaries.append(False)
+
+        (
+            non_materialized_catalogs,
+            non_materialized_namespaces,
+            non_materialized_names,
+            non_materialized_descriptions,
+            non_materialized_table_types,
+            non_materialized_is_temporaries,
+        ) = self._list_temp_views(spark_dbName, pattern)
+        catalogs.extend(non_materialized_catalogs)
+        namespaces.extend(non_materialized_namespaces)
+        names.extend(non_materialized_names)
+        descriptions.extend(non_materialized_descriptions)
+        table_types.extend(non_materialized_table_types)
+        is_temporaries.extend(non_materialized_is_temporaries)
+
         return pandas.DataFrame(
             {
                 "name": names,
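With the two hunks above, listTables() merges temporary views tracked by the new temporary_view_cache module into the same result frame as permanent tables. A minimal client-side sketch, assuming a Spark Connect session served by snowpark-connect at a hypothetical endpoint:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # assumed endpoint
    spark.createDataFrame([(1, "a")], ["id", "name"]).createOrReplaceTempView("people_tmp")

    for t in spark.catalog.listTables():
        print(t.name, t.tableType, t.isTemporary)
    # expected to include a row like: people_tmp TEMPORARY True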
@@ -297,6 +405,36 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_tableName: str,
     ) -> pandas.DataFrame:
         """Listing a single table/view with provided name that's accessible in Snowflake."""
+
+        def _get_temp_view():
+            spark_table_name_parts = [
+                quote_name_without_upper_casing(part)
+                for part in split_fully_qualified_spark_name(spark_tableName)
+            ]
+            spark_view_name = ".".join(spark_table_name_parts)
+            temp_view = get_temp_view(spark_view_name)
+            if temp_view:
+                return pandas.DataFrame(
+                    {
+                        "name": [unquote_if_quoted(spark_table_name_parts[-1])],
+                        "catalog": [None],
+                        "namespace": [
+                            [unquote_if_quoted(spark_table_name_parts[-2])]
+                            if len(spark_table_name_parts) > 1
+                            else []
+                        ],
+                        "description": [None],
+                        "tableType": ["TEMPORARY"],
+                        "isTemporary": [True],
+                    }
+                )
+            return None
+
+        # Attempt to get the view from the non materialized views first
+        temp_view = _get_temp_view()
+        if temp_view is not None:
+            return temp_view
+
         sp_catalog = get_or_create_snowpark_session().catalog
         catalog, sf_database, sf_schema, table_name = _process_multi_layer_identifier(
             spark_tableName

@@ -360,12 +498,64 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})

+    def _list_temp_view_columns(
+        self,
+        spark_tableName: str,
+        spark_dbName: typing.Optional[str] = None,
+    ):
+        spark_view_name_parts = [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_tableName)
+        ]
+        spark_view_name_parts = (
+            self._get_temp_view_prefixes(spark_dbName) + spark_view_name_parts
+        )
+        spark_view_name = ".".join(spark_view_name_parts)
+        temp_view = get_temp_view(spark_view_name)
+
+        if not temp_view:
+            return None
+
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        data_types: list[str] = list()
+        nullables: list[bool] = list()
+        is_partitions: list[bool] = list()
+        is_buckets: list[bool] = list()
+
+        for field, spark_column in zip(
+            temp_view.dataframe.schema.fields,
+            temp_view.column_map.get_spark_columns(),
+        ):
+            names.append(spark_column)
+            descriptions.append(None)
+            data_types.append(field.datatype.simpleString())
+            nullables.append(field.nullable)
+            is_partitions.append(False)
+            is_buckets.append(False)
+
+        return pandas.DataFrame(
+            {
+                "name": names,
+                "description": descriptions,
+                "dataType": data_types,
+                "nullable": nullables,
+                "isPartition": is_partitions,
+                "isBucket": is_buckets,
+            }
+        )
+
     def listColumns(
         self,
         spark_tableName: str,
         spark_dbName: typing.Optional[str] = None,
     ) -> pandas.DataFrame:
         """List all columns in a table/view, optionally database name filter can be provided."""
+
+        temp_view_columns = self._list_temp_view_columns(spark_tableName, spark_dbName)
+        if temp_view_columns is not None:
+            return temp_view_columns
+
         sp_catalog = get_or_create_snowpark_session().catalog
         columns: list[TableColumn] | None = None
         if spark_dbName is None:
@@ -455,17 +645,15 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_view_name: str,
     ) -> DataFrameContainer:
         session = get_or_create_snowpark_session()
-
-
-
-
-        )
-        result_df = result_df.select(
-            functions.contains('"status"', functions.lit("successfully dropped")).alias(
-                "value"
+        if not spark_view_name == "":
+            schema = global_config.spark_sql_globalTempDatabase
+            result = unregister_temp_view(
+                f"{spark_to_sf_single_id_with_unquoting(schema)}.{spark_to_sf_single_id_with_unquoting(spark_view_name)}"
             )
-
+        else:
+            result = False
         columns = ["value"]
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,

@@ -479,15 +667,14 @@ class SnowflakeCatalog(AbstractSparkCatalog):
     ) -> DataFrameContainer:
         """Drop the current temporary view."""
         session = get_or_create_snowpark_session()
-        result = session.sql(
-            "drop view if exists identifier(?)",
-            params=[sf_quote(spark_view_name)],
-        ).collect()
-        view_was_dropped = (
-            len(result) == 1 and "successfully dropped" in result[0]["status"]
-        )
-        result_df = session.createDataFrame([(view_was_dropped,)], schema=["value"])
         columns = ["value"]
+        if spark_view_name:
+            result = unregister_temp_view(
+                spark_to_sf_single_id_with_unquoting(spark_view_name)
+            )
+        else:
+            result = False
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,
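Both drop paths above now resolve the view against the in-memory temporary view cache (unregister_temp_view) instead of issuing a DROP VIEW statement. A minimal client-side sketch of the resulting behavior, assuming a Spark Connect session served by snowpark-connect at a hypothetical endpoint:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()  # assumed endpoint
    spark.range(3).createOrReplaceTempView("t_local")
    spark.range(3).createOrReplaceGlobalTempView("t_global")

    print(spark.catalog.dropTempView("t_local"))          # expected: True
    print(spark.catalog.dropGlobalTempView("t_global"))   # expected: True
    print(spark.catalog.dropTempView("does_not_exist"))   # expected: False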
snowflake/snowpark_connect/relation/map_extension.py

@@ -23,7 +23,6 @@ from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
-    not_resolving_fun_args,
     push_outer_dataframe,
     set_current_grouping_columns,
 )

@@ -336,15 +335,14 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)

     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-
-
-
+        new_names, snowpark_column = map_expression(
+            exp, input_container.column_map, typer
+        )
+        if len(new_names) != 1:
+            raise SnowparkConnectNotImplementedError(
+                "Multi-column aggregate expressions are not supported"
             )
-
-        raise SnowparkConnectNotImplementedError(
-            "Multi-column aggregate expressions are not supported"
-        )
-        return new_names[0], snowpark_column
+        return new_names[0], snowpark_column

     raw_groupings: list[tuple[str, TypedColumn]] = []
     raw_aggregations: list[tuple[str, TypedColumn]] = []

@@ -431,12 +429,18 @@ def map_aggregate(
             if groupings:
                 # Normal GROUP BY with explicit grouping columns
                 result = input_df.group_by(groupings)
-
+            elif not is_group_by_all:
                 # No explicit GROUP BY - this is an aggregate over the entire table
                 # Use a dummy constant that will be excluded from the final result
                 result = input_df.with_column(
                     "__dummy_group__", snowpark_fn.lit(1)
                 ).group_by("__dummy_group__")
+            else:
+                # GROUP BY ALL with only one aggregate column
+                # Snowpark doesn't support GROUP BY ALL
+                # TODO: Change in future with Snowpark Supported arguments or API for GROUP BY ALL
+                result = input_df.group_by()
+
         case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
             result = input_df.rollup(groupings)
         case snowflake_proto.Aggregate.GROUP_TYPE_CUBE: