snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +12 -3
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
- snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
- snowflake/snowpark_connect/relation/io_utils.py +21 -1
- snowflake/snowpark_connect/relation/map_extension.py +21 -4
- snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
- snowflake/snowpark_connect/relation/map_relation.py +1 -3
- snowflake/snowpark_connect/relation/map_sql.py +112 -53
- snowflake/snowpark_connect/relation/read/map_read.py +22 -3
- snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
- snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
- snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
- snowflake/snowpark_connect/relation/stage_locator.py +85 -53
- snowflake/snowpark_connect/relation/write/map_write.py +95 -14
- snowflake/snowpark_connect/server.py +18 -13
- snowflake/snowpark_connect/utils/context.py +21 -14
- snowflake/snowpark_connect/utils/identifiers.py +8 -2
- snowflake/snowpark_connect/utils/io_utils.py +36 -0
- snowflake/snowpark_connect/utils/session.py +3 -0
- snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
- snowflake/snowpark_connect/utils/udf_cache.py +37 -7
- snowflake/snowpark_connect/utils/udf_utils.py +9 -8
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py

@@ -13,7 +13,6 @@ from snowflake.core.exceptions import APIError, NotFoundError
 from snowflake.core.schema import Schema
 from snowflake.core.table import Table, TableColumn

-from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,
@@ -34,12 +33,19 @@ from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import
 )
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils.identifiers import (
+    FQN,
+    spark_to_sf_single_id_with_unquoting,
     split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
+from snowflake.snowpark_connect.utils.temporary_view_cache import (
+    get_temp_view,
+    get_temp_view_normalized_names,
+    unregister_temp_view,
+)
 from snowflake.snowpark_connect.utils.udf_cache import cached_udf


@@ -203,6 +209,93 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})

+    def _get_temp_view_prefixes(self, spark_dbName: str | None) -> list[str]:
+        if spark_dbName is None:
+            return []
+        return [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_dbName)
+        ]
+
+    def _list_temp_views(
+        self,
+        spark_dbName: str | None = None,
+        pattern: str | None = None,
+    ) -> typing.Tuple[
+        list[str | None],
+        list[list[str | None]],
+        list[str],
+        list[str | None],
+        list[str | None],
+        list[bool],
+    ]:
+        catalogs: list[str | None] = list()
+        namespaces: list[list[str | None]] = list()
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        table_types: list[str | None] = list()
+        is_temporaries: list[bool] = list()
+
+        temp_views_prefix = ".".join(self._get_temp_view_prefixes(spark_dbName))
+        normalized_spark_dbName = (
+            temp_views_prefix.lower()
+            if global_config.spark_sql_caseSensitive
+            else temp_views_prefix
+        )
+        normalized_global_temp_database_name = (
+            quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase.lower()
+            )
+            if global_config.spark_sql_caseSensitive
+            else quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase
+            )
+        )
+
+        temp_views = get_temp_view_normalized_names()
+        null_safe_pattern = pattern if pattern is not None else ""
+
+        for temp_view in temp_views:
+            normalized_temp_view = (
+                temp_view.lower()
+                if global_config.spark_sql_caseSensitive
+                else temp_view
+            )
+            fqn = FQN.from_string(temp_view)
+            normalized_schema = (
+                fqn.schema.lower()
+                if fqn.schema is not None and global_config.spark_sql_caseSensitive
+                else fqn.schema
+            )
+
+            is_global_view = normalized_global_temp_database_name == normalized_schema
+            is_local_temp_view = fqn.schema is None
+            # Temporary views are always shown if they match the pattern
+            matches_prefix = (
+                normalized_spark_dbName == normalized_schema or is_local_temp_view
+            )
+            if matches_prefix and bool(
+                re.match(null_safe_pattern, normalized_temp_view)
+            ):
+                names.append(unquote_if_quoted(fqn.name))
+                catalogs.append(None)
+                namespaces.append(
+                    [global_config.spark_sql_globalTempDatabase]
+                    if is_global_view
+                    else []
+                )
+                descriptions.append(None)
+                table_types.append("TEMPORARY")
+                is_temporaries.append(True)
+        return (
+            catalogs,
+            namespaces,
+            names,
+            descriptions,
+            table_types,
+            is_temporaries,
+        )
+
     def listTables(
         self,
         spark_dbName: str | None = None,
@@ -232,8 +325,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             schema=sf_quote(sf_schema),
             pattern=_normalize_identifier(pattern),
         )
-
-        catalogs: list[str] = list()
+        catalogs: list[str | None] = list()
         namespaces: list[list[str | None]] = list()
         names: list[str] = list()
         descriptions: list[str | None] = list()
@@ -253,6 +345,22 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             descriptions.append(o[6] if o[6] else None)
             table_types.append("PERMANENT")
             is_temporaries.append(False)
+
+        (
+            non_materialized_catalogs,
+            non_materialized_namespaces,
+            non_materialized_names,
+            non_materialized_descriptions,
+            non_materialized_table_types,
+            non_materialized_is_temporaries,
+        ) = self._list_temp_views(spark_dbName, pattern)
+        catalogs.extend(non_materialized_catalogs)
+        namespaces.extend(non_materialized_namespaces)
+        names.extend(non_materialized_names)
+        descriptions.extend(non_materialized_descriptions)
+        table_types.extend(non_materialized_table_types)
+        is_temporaries.extend(non_materialized_is_temporaries)
+
         return pandas.DataFrame(
             {
                 "name": names,
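The listTables hunks above route Spark's catalog listing through the new temporary_view_cache, so session-scoped views are reported without ever being materialized in Snowflake. A minimal sketch of the client-side behavior this enables (view and column names are illustrative, not from the diff):

    # Temp views registered on a Spark Connect session backed by
    # snowpark-connect should now show up in listTables() alongside
    # permanent Snowflake tables.
    df = spark.createDataFrame([(1, "a")], ["id", "val"])
    df.createOrReplaceTempView("my_temp_view")

    for table in spark.catalog.listTables():
        # Cached temp views are reported with tableType "TEMPORARY",
        # no catalog, and isTemporary=True, per the columns built above.
        print(table.name, table.tableType, table.isTemporary)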
@@ -297,6 +405,36 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_tableName: str,
     ) -> pandas.DataFrame:
         """Listing a single table/view with provided name that's accessible in Snowflake."""
+
+        def _get_temp_view():
+            spark_table_name_parts = [
+                quote_name_without_upper_casing(part)
+                for part in split_fully_qualified_spark_name(spark_tableName)
+            ]
+            spark_view_name = ".".join(spark_table_name_parts)
+            temp_view = get_temp_view(spark_view_name)
+            if temp_view:
+                return pandas.DataFrame(
+                    {
+                        "name": [unquote_if_quoted(spark_table_name_parts[-1])],
+                        "catalog": [None],
+                        "namespace": [
+                            [unquote_if_quoted(spark_table_name_parts[-2])]
+                            if len(spark_table_name_parts) > 1
+                            else []
+                        ],
+                        "description": [None],
+                        "tableType": ["TEMPORARY"],
+                        "isTemporary": [True],
+                    }
+                )
+            return None
+
+        # Attempt to get the view from the non materialized views first
+        temp_view = _get_temp_view()
+        if temp_view is not None:
+            return temp_view
+
         sp_catalog = get_or_create_snowpark_session().catalog
         catalog, sf_database, sf_schema, table_name = _process_multi_layer_identifier(
             spark_tableName
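getTable now consults the non-materialized view cache before falling back to the Snowflake catalog, deriving the namespace from the qualified Spark name. A sketch, assuming a standard Spark Connect session `spark` and a hypothetical view name:

    df.createGlobalTempView("my_global_view")
    t = spark.catalog.getTable("global_temp.my_global_view")
    # The schema part of the qualified name becomes the namespace,
    # mirroring the "namespace" column built in _get_temp_view above.
    print(t.name, t.namespace, t.tableType)  # my_global_view ['global_temp'] TEMPORARY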
@@ -360,12 +498,64 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})

+    def _list_temp_view_columns(
+        self,
+        spark_tableName: str,
+        spark_dbName: typing.Optional[str] = None,
+    ):
+        spark_view_name_parts = [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_tableName)
+        ]
+        spark_view_name_parts = (
+            self._get_temp_view_prefixes(spark_dbName) + spark_view_name_parts
+        )
+        spark_view_name = ".".join(spark_view_name_parts)
+        temp_view = get_temp_view(spark_view_name)
+
+        if not temp_view:
+            return None
+
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        data_types: list[str] = list()
+        nullables: list[bool] = list()
+        is_partitions: list[bool] = list()
+        is_buckets: list[bool] = list()
+
+        for field, spark_column in zip(
+            temp_view.dataframe.schema.fields,
+            temp_view.column_map.get_spark_columns(),
+        ):
+            names.append(spark_column)
+            descriptions.append(None)
+            data_types.append(field.datatype.simpleString())
+            nullables.append(field.nullable)
+            is_partitions.append(False)
+            is_buckets.append(False)
+
+        return pandas.DataFrame(
+            {
+                "name": names,
+                "description": descriptions,
+                "dataType": data_types,
+                "nullable": nullables,
+                "isPartition": is_partitions,
+                "isBucket": is_buckets,
+            }
+        )
+
     def listColumns(
         self,
         spark_tableName: str,
         spark_dbName: typing.Optional[str] = None,
     ) -> pandas.DataFrame:
         """List all columns in a table/view, optionally database name filter can be provided."""
+
+        temp_view_columns = self._list_temp_view_columns(spark_tableName, spark_dbName)
+        if temp_view_columns is not None:
+            return temp_view_columns
+
         sp_catalog = get_or_create_snowpark_session().catalog
         columns: list[TableColumn] | None = None
         if spark_dbName is None:
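For cached views, listColumns now answers from the DataFrame's Snowpark schema instead of querying Snowflake. A sketch with illustrative names:

    df = spark.createDataFrame([(1, "a")], ["id", "val"])
    df.createOrReplaceTempView("my_temp_view")
    for col in spark.catalog.listColumns("my_temp_view"):
        # dataType comes from DataType.simpleString(), e.g. "bigint"/"string";
        # isPartition and isBucket are always False for temp views.
        print(col.name, col.dataType, col.nullable)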
@@ -455,17 +645,15 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_view_name: str,
     ) -> DataFrameContainer:
         session = get_or_create_snowpark_session()
-
-
-
-
-        )
-        result_df = result_df.select(
-            functions.contains('"status"', functions.lit("successfully dropped")).alias(
-                "value"
+        if not spark_view_name == "":
+            schema = global_config.spark_sql_globalTempDatabase
+            result = unregister_temp_view(
+                f"{spark_to_sf_single_id_with_unquoting(schema)}.{spark_to_sf_single_id_with_unquoting(spark_view_name)}"
             )
-
+        else:
+            result = False
         columns = ["value"]
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,
@@ -479,15 +667,14 @@ class SnowflakeCatalog(AbstractSparkCatalog):
     ) -> DataFrameContainer:
         """Drop the current temporary view."""
         session = get_or_create_snowpark_session()
-        result = session.sql(
-            "drop view if exists identifier(?)",
-            params=[sf_quote(spark_view_name)],
-        ).collect()
-        view_was_dropped = (
-            len(result) == 1 and "successfully dropped" in result[0]["status"]
-        )
-        result_df = session.createDataFrame([(view_was_dropped,)], schema=["value"])
         columns = ["value"]
+        if spark_view_name:
+            result = unregister_temp_view(
+                spark_to_sf_single_id_with_unquoting(spark_view_name)
+            )
+        else:
+            result = False
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,
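Both drop paths now report the result of unregister_temp_view on the in-memory cache instead of parsing the status string of a DROP VIEW statement. Assuming unregister_temp_view returns whether a view was actually removed, the observable client behavior would be:

    df.createOrReplaceTempView("my_temp_view")
    assert spark.catalog.dropTempView("my_temp_view") is True
    # Nothing left to unregister on the second call.
    assert spark.catalog.dropTempView("my_temp_view") is False
    # An empty view name short-circuits to False without touching the cache.
    assert spark.catalog.dropTempView("") is False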
snowflake/snowpark_connect/relation/io_utils.py

@@ -7,8 +7,27 @@ from urllib.parse import urlparse
 CLOUD_PREFIX_TO_CLOUD = {
     "abfss": "azure",
     "wasbs": "azure",
+    "gcs": "gcp",
+    "gs": "gcp",
 }

+SUPPORTED_COMPRESSION_PER_FORMAT = {
+    "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+    "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+    "text": {"NONE"},
+}
+
+
+def supported_compressions_for_format(format: str) -> set[str]:
+    return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+def is_supported_compression(format: str, compression: str | None) -> bool:
+    if compression is None:
+        return True
+    return compression in supported_compressions_for_format(format)
+

 def get_cloud_from_url(
     url: str,
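These helpers give the read/write mappers one place to validate a compression option against what Snowflake accepts for each format. Their behavior follows directly from the table above:

    from snowflake.snowpark_connect.relation.io_utils import (
        is_supported_compression,
        supported_compressions_for_format,
    )

    assert is_supported_compression("csv", "GZIP")
    assert is_supported_compression("parquet", "SNAPPY")
    assert not is_supported_compression("text", "GZIP")  # text allows only NONE
    assert is_supported_compression("json", None)  # an unset option always passes
    assert supported_compressions_for_format("avro") == set()  # unknown format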
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
         or path.startswith("azure://")
         or path.startswith("abfss://")
         or path.startswith("wasbs://")  # Azure
-        or path.startswith("gcs://")
+        or path.startswith("gcs://")
+        or path.startswith("gs://")  # GCP
     )


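With "gs" mapped to "gcp" and accepted by is_cloud_path, the canonical gs:// scheme for Google Cloud Storage should now behave like the previously supported gcs:// spelling. A hypothetical read (bucket and path are placeholders):

    # Both spellings should now resolve to the same GCP cloud-path handling.
    df_a = spark.read.parquet("gcs://my-bucket/events/")
    df_b = spark.read.parquet("gs://my-bucket/events/")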
snowflake/snowpark_connect/relation/map_extension.py

@@ -345,7 +345,7 @@ def map_aggregate(
         return new_names[0], snowpark_column

     raw_groupings: list[tuple[str, TypedColumn]] = []
-    raw_aggregations: list[tuple[str, TypedColumn]] = []
+    raw_aggregations: list[tuple[str, TypedColumn, list[str]]] = []

     if not is_group_by_all:
         raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -375,10 +375,21 @@ def map_aggregate(
     # Note: We don't clear the map here to preserve any parent context aliases
     from snowflake.snowpark_connect.utils.context import register_lca_alias

+    # If it's an unresolved attribute when its in aggregate.aggregate_expressions, we know it came from the parent map straight away
+    # in this case, we should see if the parent map has a qualifier for it and propagate that here, in case the order by references it in
+    # a qualified way later.
     agg_count = get_sql_aggregate_function_count()
     for exp in aggregate.aggregate_expressions:
         col = _map_column(exp)
-
+        if exp.WhichOneof("expr_type") == "unresolved_attribute":
+            spark_name = col[0]
+            qualifiers = input_container.column_map.get_qualifier_for_spark_column(
+                spark_name
+            )
+        else:
+            qualifiers = []
+
+        raw_aggregations.append((col[0], col[1], qualifiers))

         # If this is an alias, register it in the LCA map for subsequent expressions
         if (
@@ -409,18 +420,20 @@ def map_aggregate(
     spark_columns: list[str] = []
     snowpark_columns: list[str] = []
     snowpark_column_types: list[snowpark_types.DataType] = []
+    all_qualifiers: list[list[str]] = []

     # Use grouping columns directly without aliases
     groupings = [col.col for _, col in raw_groupings]

     # Create aliases only for aggregation columns
     aggregations = []
-    for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+    for i, (spark_name, snowpark_column, qualifiers) in enumerate(raw_aggregations):
         alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]

         spark_columns.append(spark_name)
         snowpark_columns.append(alias)
         snowpark_column_types.append(snowpark_column.typ)
+        all_qualifiers.append(qualifiers)

         aggregations.append(snowpark_column.col.alias(alias))

@@ -483,6 +496,7 @@ def map_aggregate(
         spark_column_names=spark_columns,
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
+        column_qualifiers=all_qualifiers,
     ).column_map

     # Create hybrid column map that can resolve both input and aggregate contexts
@@ -494,7 +508,9 @@ def map_aggregate(
         aggregate_expressions=list(aggregate.aggregate_expressions),
         grouping_expressions=list(aggregate.grouping_expressions),
         spark_columns=spark_columns,
-        raw_aggregations=raw_aggregations,
+        raw_aggregations=[
+            (spark_name, col) for spark_name, col, _ in raw_aggregations
+        ],
     )

     # Map the HAVING condition using hybrid resolution
@@ -515,4 +531,5 @@ def map_aggregate(
         snowpark_column_names=snowpark_columns,
         snowpark_column_types=snowpark_column_types,
         parent_column_name_map=input_df._column_map,
+        column_qualifiers=all_qualifiers,
     )
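Taken together, the map_aggregate changes thread each aggregate column's qualifier from the parent plan into the aggregate's column map, so a later sort can still reference a grouped column by its qualified name. A sketch of the kind of query this targets, with illustrative names:

    from pyspark.sql import functions as F

    orders = spark.createDataFrame(
        [(1, 10.0), (1, 5.0), (2, 7.0)], ["id", "amt"]
    ).alias("o")

    # "o.id" survives the aggregate because its qualifier is propagated,
    # so the qualified reference in orderBy can still be resolved.
    result = (
        orders.groupBy("o.id")
        .agg(F.sum("amt").alias("total"))
        .orderBy("o.id")
    )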
snowflake/snowpark_connect/relation/map_map_partitions.py

@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
 from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
 from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
-from snowflake.snowpark_connect.utils.context import map_partitions_depth
 from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
     create_pandas_udtf,
     create_pandas_udtf_with_arrow,
@@ -53,18 +52,18 @@ def _call_udtf(
         ).cast("int"),
     )

-    udtf_columns = input_df.columns + ["_DUMMY_PARTITION_KEY"]
+    udtf_columns = [f"snowflake_jtf_{column}" for column in input_df.columns] + [
+        "_DUMMY_PARTITION_KEY"
+    ]

     tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
         partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
     )

-    #
-
-
-
-    else:
-        result_df_with_dummy = input_df_with_dummy.select(tfc)
+    # Overwrite the input_df columns to prevent name conflicts with UDTF output columns
+    result_df_with_dummy = input_df_with_dummy.to_df(udtf_columns).join_table_function(
+        tfc
+    )

     output_cols = [field.name for field in return_type.fields]

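Prefixing the UDTF's input columns with snowflake_jtf_ keeps them from colliding with same-named output columns when the UDTF results are joined back to the input. A sketch of the mapInPandas pattern that exercises this, assuming the usual Spark Connect session (function and column names are illustrative):

    import pandas as pd

    def double_ids(batches):
        for pdf in batches:
            yield pd.DataFrame({"id": pdf["id"] * 2})

    df = spark.createDataFrame([(1,), (2,)], ["id"])
    # The output schema reuses the input column name "id"; the server-side
    # rename keeps the joined UDTF output unambiguous.
    df.mapInPandas(double_ids, schema="id long").show()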
snowflake/snowpark_connect/relation/map_relation.py

@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.context import (
     get_plan_id_map,
     get_session_id,
     not_resolving_fun_args,
-    push_map_partitions,
     push_operation_scope,
     set_is_aggregate_function,
     set_plan_id_map,
@@ -185,8 +184,7 @@ def map_relation(
             )
             return cached_df
         case "map_partitions":
-
-                result = map_map_partitions.map_map_partitions(rel)
+            result = map_map_partitions.map_map_partitions(rel)
         case "offset":
             result = map_row_ops.map_offset(rel)
         case "project":