snowpark-connect 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Note: the diff service flags this version of snowpark-connect as potentially problematic.

Files changed (36)
  1. snowflake/snowpark_connect/config.py +12 -3
  2. snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
  3. snowflake/snowpark_connect/expression/map_unresolved_function.py +172 -210
  4. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
  5. snowflake/snowpark_connect/relation/io_utils.py +21 -1
  6. snowflake/snowpark_connect/relation/map_extension.py +21 -4
  7. snowflake/snowpark_connect/relation/map_map_partitions.py +7 -8
  8. snowflake/snowpark_connect/relation/map_relation.py +1 -3
  9. snowflake/snowpark_connect/relation/map_sql.py +112 -53
  10. snowflake/snowpark_connect/relation/read/map_read.py +22 -3
  11. snowflake/snowpark_connect/relation/read/map_read_csv.py +105 -26
  12. snowflake/snowpark_connect/relation/read/map_read_json.py +45 -34
  13. snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
  14. snowflake/snowpark_connect/relation/read/map_read_text.py +6 -1
  15. snowflake/snowpark_connect/relation/stage_locator.py +85 -53
  16. snowflake/snowpark_connect/relation/write/map_write.py +95 -14
  17. snowflake/snowpark_connect/server.py +18 -13
  18. snowflake/snowpark_connect/utils/context.py +21 -14
  19. snowflake/snowpark_connect/utils/identifiers.py +8 -2
  20. snowflake/snowpark_connect/utils/io_utils.py +36 -0
  21. snowflake/snowpark_connect/utils/session.py +3 -0
  22. snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
  23. snowflake/snowpark_connect/utils/udf_cache.py +37 -7
  24. snowflake/snowpark_connect/utils/udf_utils.py +9 -8
  25. snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
  26. snowflake/snowpark_connect/version.py +1 -1
  27. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/METADATA +3 -2
  28. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/RECORD +36 -35
  29. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-connect +0 -0
  30. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-session +0 -0
  31. {snowpark_connect-0.28.0.data → snowpark_connect-0.29.0.data}/scripts/snowpark-submit +0 -0
  32. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/WHEEL +0 -0
  33. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE-binary +0 -0
  34. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/licenses/NOTICE-binary +0 -0
  36. {snowpark_connect-0.28.0.dist-info → snowpark_connect-0.29.0.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@ from snowflake.core.exceptions import APIError, NotFoundError
  from snowflake.core.schema import Schema
  from snowflake.core.table import Table, TableColumn

- from snowflake.snowpark import functions
  from snowflake.snowpark._internal.analyzer.analyzer_utils import (
      quote_name_without_upper_casing,
      unquote_if_quoted,
@@ -34,12 +33,19 @@ from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import
  )
  from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
  from snowflake.snowpark_connect.utils.identifiers import (
+     FQN,
+     spark_to_sf_single_id_with_unquoting,
      split_fully_qualified_spark_name,
  )
  from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
  from snowflake.snowpark_connect.utils.telemetry import (
      SnowparkConnectNotImplementedError,
  )
+ from snowflake.snowpark_connect.utils.temporary_view_cache import (
+     get_temp_view,
+     get_temp_view_normalized_names,
+     unregister_temp_view,
+ )
  from snowflake.snowpark_connect.utils.udf_cache import cached_udf


@@ -203,6 +209,93 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          exists = False
          return pandas.DataFrame({"exists": [exists]})

+     def _get_temp_view_prefixes(self, spark_dbName: str | None) -> list[str]:
+         if spark_dbName is None:
+             return []
+         return [
+             quote_name_without_upper_casing(part)
+             for part in split_fully_qualified_spark_name(spark_dbName)
+         ]
+
+     def _list_temp_views(
+         self,
+         spark_dbName: str | None = None,
+         pattern: str | None = None,
+     ) -> typing.Tuple[
+         list[str | None],
+         list[list[str | None]],
+         list[str],
+         list[str | None],
+         list[str | None],
+         list[bool],
+     ]:
+         catalogs: list[str | None] = list()
+         namespaces: list[list[str | None]] = list()
+         names: list[str] = list()
+         descriptions: list[str | None] = list()
+         table_types: list[str | None] = list()
+         is_temporaries: list[bool] = list()
+
+         temp_views_prefix = ".".join(self._get_temp_view_prefixes(spark_dbName))
+         normalized_spark_dbName = (
+             temp_views_prefix.lower()
+             if global_config.spark_sql_caseSensitive
+             else temp_views_prefix
+         )
+         normalized_global_temp_database_name = (
+             quote_name_without_upper_casing(
+                 global_config.spark_sql_globalTempDatabase.lower()
+             )
+             if global_config.spark_sql_caseSensitive
+             else quote_name_without_upper_casing(
+                 global_config.spark_sql_globalTempDatabase
+             )
+         )
+
+         temp_views = get_temp_view_normalized_names()
+         null_safe_pattern = pattern if pattern is not None else ""
+
+         for temp_view in temp_views:
+             normalized_temp_view = (
+                 temp_view.lower()
+                 if global_config.spark_sql_caseSensitive
+                 else temp_view
+             )
+             fqn = FQN.from_string(temp_view)
+             normalized_schema = (
+                 fqn.schema.lower()
+                 if fqn.schema is not None and global_config.spark_sql_caseSensitive
+                 else fqn.schema
+             )
+
+             is_global_view = normalized_global_temp_database_name == normalized_schema
+             is_local_temp_view = fqn.schema is None
+             # Temporary views are always shown if they match the pattern
+             matches_prefix = (
+                 normalized_spark_dbName == normalized_schema or is_local_temp_view
+             )
+             if matches_prefix and bool(
+                 re.match(null_safe_pattern, normalized_temp_view)
+             ):
+                 names.append(unquote_if_quoted(fqn.name))
+                 catalogs.append(None)
+                 namespaces.append(
+                     [global_config.spark_sql_globalTempDatabase]
+                     if is_global_view
+                     else []
+                 )
+                 descriptions.append(None)
+                 table_types.append("TEMPORARY")
+                 is_temporaries.append(True)
+         return (
+             catalogs,
+             namespaces,
+             names,
+             descriptions,
+             table_types,
+             is_temporaries,
+         )
+
      def listTables(
          self,
          spark_dbName: str | None = None,
@@ -232,8 +325,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
              schema=sf_quote(sf_schema),
              pattern=_normalize_identifier(pattern),
          )
-         names: list[str] = list()
-         catalogs: list[str] = list()
+         catalogs: list[str | None] = list()
          namespaces: list[list[str | None]] = list()
          names: list[str] = list()
          descriptions: list[str | None] = list()
@@ -253,6 +345,22 @@ class SnowflakeCatalog(AbstractSparkCatalog):
              descriptions.append(o[6] if o[6] else None)
              table_types.append("PERMANENT")
              is_temporaries.append(False)
+
+         (
+             non_materialized_catalogs,
+             non_materialized_namespaces,
+             non_materialized_names,
+             non_materialized_descriptions,
+             non_materialized_table_types,
+             non_materialized_is_temporaries,
+         ) = self._list_temp_views(spark_dbName, pattern)
+         catalogs.extend(non_materialized_catalogs)
+         namespaces.extend(non_materialized_namespaces)
+         names.extend(non_materialized_names)
+         descriptions.extend(non_materialized_descriptions)
+         table_types.extend(non_materialized_table_types)
+         is_temporaries.extend(non_materialized_is_temporaries)
+
          return pandas.DataFrame(
              {
                  "name": names,
@@ -297,6 +405,36 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          spark_tableName: str,
      ) -> pandas.DataFrame:
          """Listing a single table/view with provided name that's accessible in Snowflake."""
+
+         def _get_temp_view():
+             spark_table_name_parts = [
+                 quote_name_without_upper_casing(part)
+                 for part in split_fully_qualified_spark_name(spark_tableName)
+             ]
+             spark_view_name = ".".join(spark_table_name_parts)
+             temp_view = get_temp_view(spark_view_name)
+             if temp_view:
+                 return pandas.DataFrame(
+                     {
+                         "name": [unquote_if_quoted(spark_table_name_parts[-1])],
+                         "catalog": [None],
+                         "namespace": [
+                             [unquote_if_quoted(spark_table_name_parts[-2])]
+                             if len(spark_table_name_parts) > 1
+                             else []
+                         ],
+                         "description": [None],
+                         "tableType": ["TEMPORARY"],
+                         "isTemporary": [True],
+                     }
+                 )
+             return None
+
+         # Attempt to get the view from the non materialized views first
+         temp_view = _get_temp_view()
+         if temp_view is not None:
+             return temp_view
+
          sp_catalog = get_or_create_snowpark_session().catalog
          catalog, sf_database, sf_schema, table_name = _process_multi_layer_identifier(
              spark_tableName
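
The hunks above add an in-memory temporary-view cache to the catalog: _list_temp_views feeds listTables, and the nested _get_temp_view short-circuits the single-table lookup before the Snowflake catalog is queried. A hedged client-side sketch of the intended behaviour (assumes `spark` is a Spark Connect session backed by snowpark-connect; the view names are illustrative only):

spark.range(3).createOrReplaceTempView("my_temp_view")
spark.range(3).createOrReplaceGlobalTempView("my_global_view")

# listTables() should now report the cached views with tableType "TEMPORARY"
# alongside permanent Snowflake tables.
for table in spark.catalog.listTables():
    print(table.name, table.tableType, table.isTemporary)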
@@ -360,12 +498,64 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          exists = False
          return pandas.DataFrame({"exists": [exists]})

+     def _list_temp_view_columns(
+         self,
+         spark_tableName: str,
+         spark_dbName: typing.Optional[str] = None,
+     ):
+         spark_view_name_parts = [
+             quote_name_without_upper_casing(part)
+             for part in split_fully_qualified_spark_name(spark_tableName)
+         ]
+         spark_view_name_parts = (
+             self._get_temp_view_prefixes(spark_dbName) + spark_view_name_parts
+         )
+         spark_view_name = ".".join(spark_view_name_parts)
+         temp_view = get_temp_view(spark_view_name)
+
+         if not temp_view:
+             return None
+
+         names: list[str] = list()
+         descriptions: list[str | None] = list()
+         data_types: list[str] = list()
+         nullables: list[bool] = list()
+         is_partitions: list[bool] = list()
+         is_buckets: list[bool] = list()
+
+         for field, spark_column in zip(
+             temp_view.dataframe.schema.fields,
+             temp_view.column_map.get_spark_columns(),
+         ):
+             names.append(spark_column)
+             descriptions.append(None)
+             data_types.append(field.datatype.simpleString())
+             nullables.append(field.nullable)
+             is_partitions.append(False)
+             is_buckets.append(False)
+
+         return pandas.DataFrame(
+             {
+                 "name": names,
+                 "description": descriptions,
+                 "dataType": data_types,
+                 "nullable": nullables,
+                 "isPartition": is_partitions,
+                 "isBucket": is_buckets,
+             }
+         )
+
      def listColumns(
          self,
          spark_tableName: str,
          spark_dbName: typing.Optional[str] = None,
      ) -> pandas.DataFrame:
          """List all columns in a table/view, optionally database name filter can be provided."""
+
+         temp_view_columns = self._list_temp_view_columns(spark_tableName, spark_dbName)
+         if temp_view_columns is not None:
+             return temp_view_columns
+
          sp_catalog = get_or_create_snowpark_session().catalog
          columns: list[TableColumn] | None = None
          if spark_dbName is None:
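
With _list_temp_view_columns in place, listColumns answers from the temp-view cache (the view's schema fields paired with its recorded Spark column names) before falling back to the Snowflake catalog. A hedged sketch, again assuming a snowpark-connect-backed `spark` session:

spark.createDataFrame([(1, "a")], "id int, name string").createOrReplaceTempView("people")

# Expected to be served from the cached view rather than a catalog query.
for column in spark.catalog.listColumns("people"):
    print(column.name, column.dataType, column.nullable)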
@@ -455,17 +645,15 @@ class SnowflakeCatalog(AbstractSparkCatalog):
          spark_view_name: str,
      ) -> DataFrameContainer:
          session = get_or_create_snowpark_session()
-         schema = global_config.spark_sql_globalTempDatabase
-         result_df = session.sql(
-             "drop view if exists identifier(?)",
-             params=[f"{sf_quote(schema)}.{sf_quote(spark_view_name)}"],
-         )
-         result_df = result_df.select(
-             functions.contains('"status"', functions.lit("successfully dropped")).alias(
-                 "value"
+         if not spark_view_name == "":
+             schema = global_config.spark_sql_globalTempDatabase
+             result = unregister_temp_view(
+                 f"{spark_to_sf_single_id_with_unquoting(schema)}.{spark_to_sf_single_id_with_unquoting(spark_view_name)}"
              )
-         )
+         else:
+             result = False
          columns = ["value"]
+         result_df = session.createDataFrame([result], schema=columns)
          return DataFrameContainer.create_with_column_mapping(
              dataframe=result_df,
              spark_column_names=columns,
@@ -479,15 +667,14 @@ class SnowflakeCatalog(AbstractSparkCatalog):
      ) -> DataFrameContainer:
          """Drop the current temporary view."""
          session = get_or_create_snowpark_session()
-         result = session.sql(
-             "drop view if exists identifier(?)",
-             params=[sf_quote(spark_view_name)],
-         ).collect()
-         view_was_dropped = (
-             len(result) == 1 and "successfully dropped" in result[0]["status"]
-         )
-         result_df = session.createDataFrame([(view_was_dropped,)], schema=["value"])
          columns = ["value"]
+         if spark_view_name:
+             result = unregister_temp_view(
+                 spark_to_sf_single_id_with_unquoting(spark_view_name)
+             )
+         else:
+             result = False
+         result_df = session.createDataFrame([result], schema=columns)
          return DataFrameContainer.create_with_column_mapping(
              dataframe=result_df,
              spark_column_names=columns,
@@ -7,8 +7,27 @@ from urllib.parse import urlparse
  CLOUD_PREFIX_TO_CLOUD = {
      "abfss": "azure",
      "wasbs": "azure",
+     "gcs": "gcp",
+     "gs": "gcp",
  }

+ SUPPORTED_COMPRESSION_PER_FORMAT = {
+     "csv": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+     "json": {"AUTO", "GZIP", "BZ2", "BROTLI", "ZSTD", "DEFLATE", "RAW_DEFLATE", "NONE"},
+     "parquet": {"AUTO", "LZO", "SNAPPY", "NONE"},
+     "text": {"NONE"},
+ }
+
+
+ def supported_compressions_for_format(format: str) -> set[str]:
+     return SUPPORTED_COMPRESSION_PER_FORMAT.get(format, set())
+
+
+ def is_supported_compression(format: str, compression: str | None) -> bool:
+     if compression is None:
+         return True
+     return compression in supported_compressions_for_format(format)
+

  def get_cloud_from_url(
      url: str,
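
The new compression table and helpers make the per-format validation explicit. A hedged usage sketch, assuming this hunk lands in snowflake/snowpark_connect/relation/io_utils.py as the file list suggests:

from snowflake.snowpark_connect.relation.io_utils import (
    is_supported_compression,
    supported_compressions_for_format,
)

print(supported_compressions_for_format("parquet"))  # AUTO, LZO, SNAPPY, NONE (set order may vary)
print(is_supported_compression("csv", "GZIP"))        # True
print(is_supported_compression("text", "GZIP"))       # False: text only allows NONE
print(is_supported_compression("json", None))         # True: no compression requested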
@@ -66,7 +85,8 @@ def is_cloud_path(path: str) -> bool:
          or path.startswith("azure://")
          or path.startswith("abfss://")
          or path.startswith("wasbs://")  # Azure
-         or path.startswith("gcs://")  # GCP
+         or path.startswith("gcs://")
+         or path.startswith("gs://")  # GCP
      )


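Together with the CLOUD_PREFIX_TO_CLOUD entries added above, this makes gs:// URLs recognised as GCP cloud paths in addition to gcs://. A hedged sketch under the same module-path assumption:

from snowflake.snowpark_connect.relation.io_utils import CLOUD_PREFIX_TO_CLOUD, is_cloud_path

print(is_cloud_path("gs://my-bucket/data/part-0.parquet"))   # True as of 0.29.0
print(is_cloud_path("gcs://my-bucket/data/part-0.parquet"))  # True, unchanged
print(CLOUD_PREFIX_TO_CLOUD["gs"])                           # "gcp"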
@@ -345,7 +345,7 @@ def map_aggregate(
          return new_names[0], snowpark_column

      raw_groupings: list[tuple[str, TypedColumn]] = []
-     raw_aggregations: list[tuple[str, TypedColumn]] = []
+     raw_aggregations: list[tuple[str, TypedColumn, list[str]]] = []

      if not is_group_by_all:
          raw_groupings = [_map_column(exp) for exp in aggregate.grouping_expressions]
@@ -375,10 +375,21 @@ def map_aggregate(
      # Note: We don't clear the map here to preserve any parent context aliases
      from snowflake.snowpark_connect.utils.context import register_lca_alias

+     # If it's an unresolved attribute when its in aggregate.aggregate_expressions, we know it came from the parent map straight away
+     # in this case, we should see if the parent map has a qualifier for it and propagate that here, in case the order by references it in
+     # a qualified way later.
      agg_count = get_sql_aggregate_function_count()
      for exp in aggregate.aggregate_expressions:
          col = _map_column(exp)
-         raw_aggregations.append(col)
+         if exp.WhichOneof("expr_type") == "unresolved_attribute":
+             spark_name = col[0]
+             qualifiers = input_container.column_map.get_qualifier_for_spark_column(
+                 spark_name
+             )
+         else:
+             qualifiers = []
+
+         raw_aggregations.append((col[0], col[1], qualifiers))

          # If this is an alias, register it in the LCA map for subsequent expressions
          if (
@@ -409,18 +420,20 @@ def map_aggregate(
      spark_columns: list[str] = []
      snowpark_columns: list[str] = []
      snowpark_column_types: list[snowpark_types.DataType] = []
+     all_qualifiers: list[list[str]] = []

      # Use grouping columns directly without aliases
      groupings = [col.col for _, col in raw_groupings]

      # Create aliases only for aggregation columns
      aggregations = []
-     for i, (spark_name, snowpark_column) in enumerate(raw_aggregations):
+     for i, (spark_name, snowpark_column, qualifiers) in enumerate(raw_aggregations):
          alias = make_column_names_snowpark_compatible([spark_name], plan_id, i)[0]

          spark_columns.append(spark_name)
          snowpark_columns.append(alias)
          snowpark_column_types.append(snowpark_column.typ)
+         all_qualifiers.append(qualifiers)

          aggregations.append(snowpark_column.col.alias(alias))

@@ -483,6 +496,7 @@ def map_aggregate(
          spark_column_names=spark_columns,
          snowpark_column_names=snowpark_columns,
          snowpark_column_types=snowpark_column_types,
+         column_qualifiers=all_qualifiers,
      ).column_map

      # Create hybrid column map that can resolve both input and aggregate contexts
@@ -494,7 +508,9 @@ def map_aggregate(
          aggregate_expressions=list(aggregate.aggregate_expressions),
          grouping_expressions=list(aggregate.grouping_expressions),
          spark_columns=spark_columns,
-         raw_aggregations=raw_aggregations,
+         raw_aggregations=[
+             (spark_name, col) for spark_name, col, _ in raw_aggregations
+         ],
      )

      # Map the HAVING condition using hybrid resolution
@@ -515,4 +531,5 @@ def map_aggregate(
          snowpark_column_names=snowpark_columns,
          snowpark_column_types=snowpark_column_types,
          parent_column_name_map=input_df._column_map,
+         column_qualifiers=all_qualifiers,
      )
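
The map_aggregate changes carry a qualifier list through the aggregation so that a later clause can still reference a column by its qualified name. A hedged Spark SQL sketch of the pattern this targets (assumes a snowpark-connect-backed `spark` session; the table and column names are illustrative):

spark.createDataFrame([(1, "a"), (2, "a"), (3, "b")], "v int, k string") \
    .createOrReplaceTempView("t")

# The ORDER BY refers to the grouping column through its qualifier "t".
spark.sql("SELECT t.k, sum(t.v) AS total FROM t GROUP BY t.k ORDER BY t.k").show()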
@@ -12,7 +12,6 @@ from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
  from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
  from snowflake.snowpark_connect.relation.map_relation import map_relation
  from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
- from snowflake.snowpark_connect.utils.context import map_partitions_depth
  from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
      create_pandas_udtf,
      create_pandas_udtf_with_arrow,
@@ -53,18 +52,18 @@ def _call_udtf(
          ).cast("int"),
      )

-     udtf_columns = input_df.columns + [snowpark_fn.col("_DUMMY_PARTITION_KEY")]
+     udtf_columns = [f"snowflake_jtf_{column}" for column in input_df.columns] + [
+         "_DUMMY_PARTITION_KEY"
+     ]

      tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
          partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
      )

-     # Use map_partitions_depth only when mapping non nested map_partitions
-     # When mapping chained functions additional column casting is necessary
-     if map_partitions_depth() == 1:
-         result_df_with_dummy = input_df_with_dummy.join_table_function(tfc)
-     else:
-         result_df_with_dummy = input_df_with_dummy.select(tfc)
+     # Overwrite the input_df columns to prevent name conflicts with UDTF output columns
+     result_df_with_dummy = input_df_with_dummy.to_df(udtf_columns).join_table_function(
+         tfc
+     )


      output_cols = [field.name for field in return_type.fields]
@@ -16,7 +16,6 @@ from snowflake.snowpark_connect.utils.context import (
      get_plan_id_map,
      get_session_id,
      not_resolving_fun_args,
-     push_map_partitions,
      push_operation_scope,
      set_is_aggregate_function,
      set_plan_id_map,
@@ -185,8 +184,7 @@ def map_relation(
              )
              return cached_df
          case "map_partitions":
-             with push_map_partitions():
-                 result = map_map_partitions.map_map_partitions(rel)
+             result = map_map_partitions.map_map_partitions(rel)
          case "offset":
              result = map_row_ops.map_offset(rel)
          case "project":
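
The map_partitions changes drop the nesting-depth tracking (push_map_partitions / map_partitions_depth) and instead rename the UDTF input columns with a snowflake_jtf_ prefix to avoid clashes with the UDTF output columns. This relation appears to back DataFrame.mapInPandas / mapInArrow; a hedged sketch of such a workload (assumes a snowpark-connect-backed `spark` session):

def add_one(batches):
    # Each element of `batches` is a pandas DataFrame holding one batch of rows.
    for pdf in batches:
        yield pdf.assign(id=pdf["id"] + 1)

spark.range(4).mapInPandas(add_one, schema="id long").show()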