snowpark-connect 0.27.0__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of snowpark-connect might be problematic.

Files changed (42)
  1. snowflake/snowpark_connect/column_name_handler.py +3 -93
  2. snowflake/snowpark_connect/config.py +99 -1
  3. snowflake/snowpark_connect/dataframe_container.py +0 -6
  4. snowflake/snowpark_connect/execute_plan/map_execution_command.py +31 -68
  5. snowflake/snowpark_connect/expression/map_expression.py +22 -7
  6. snowflake/snowpark_connect/expression/map_sql_expression.py +22 -18
  7. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +4 -26
  8. snowflake/snowpark_connect/expression/map_unresolved_function.py +12 -3
  9. snowflake/snowpark_connect/expression/map_unresolved_star.py +2 -3
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  11. snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +207 -20
  12. snowflake/snowpark_connect/relation/map_extension.py +14 -10
  13. snowflake/snowpark_connect/relation/map_join.py +62 -258
  14. snowflake/snowpark_connect/relation/map_relation.py +5 -1
  15. snowflake/snowpark_connect/relation/map_sql.py +464 -68
  16. snowflake/snowpark_connect/relation/read/map_read_table.py +58 -0
  17. snowflake/snowpark_connect/relation/write/map_write.py +228 -120
  18. snowflake/snowpark_connect/resources_initializer.py +20 -5
  19. snowflake/snowpark_connect/server.py +16 -17
  20. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  21. snowflake/snowpark_connect/utils/context.py +21 -0
  22. snowflake/snowpark_connect/utils/describe_query_cache.py +57 -51
  23. snowflake/snowpark_connect/utils/identifiers.py +128 -2
  24. snowflake/snowpark_connect/utils/io_utils.py +21 -1
  25. snowflake/snowpark_connect/utils/scala_udf_utils.py +34 -43
  26. snowflake/snowpark_connect/utils/session.py +16 -26
  27. snowflake/snowpark_connect/utils/telemetry.py +53 -0
  28. snowflake/snowpark_connect/utils/temporary_view_cache.py +61 -0
  29. snowflake/snowpark_connect/utils/udf_utils.py +9 -8
  30. snowflake/snowpark_connect/utils/udtf_utils.py +3 -2
  31. snowflake/snowpark_connect/version.py +1 -1
  32. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/RECORD +41 -41
  34. snowflake/snowpark_connect/hidden_column.py +0 -39
  35. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.27.0.data → snowpark_connect-0.28.1.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/licenses/NOTICE-binary +0 -0
  42. {snowpark_connect-0.27.0.dist-info → snowpark_connect-0.28.1.dist-info}/top_level.txt +0 -0
@@ -163,7 +163,6 @@ def map_unresolved_attribute(
     attr_name = ".".join(name_parts)
 
     has_plan_id = exp.unresolved_attribute.HasField("plan_id")
-    source_qualifiers = None
 
     if has_plan_id:
         plan_id = exp.unresolved_attribute.plan_id
@@ -172,27 +171,13 @@ def map_unresolved_attribute(
         assert (
             target_df is not None
         ), f"resolving an attribute of a unresolved dataframe {plan_id}"
-
-        # Get the qualifiers for this column from the target DataFrame
-        source_qualifiers = (
-            target_df_container.column_map.get_qualifier_for_spark_column(
-                name_parts[-1]
-            )
-        )
-
-        if hasattr(column_mapping, "hidden_columns"):
-            hidden = column_mapping.hidden_columns
-        else:
-            hidden = None
-
         column_mapping = target_df_container.column_map
-        column_mapping.hidden_columns = hidden
         typer = ExpressionTyper(target_df)
 
-    def get_col(snowpark_name, has_hidden=False):
+    def get_col(snowpark_name):
         return (
             snowpark_fn.col(snowpark_name)
-            if not has_plan_id or has_hidden
+            if not has_plan_id
             else target_df.col(snowpark_name)
         )
@@ -277,17 +262,10 @@ def map_unresolved_attribute(
     quoted_attr_name = name_parts[0]
 
     snowpark_name = column_mapping.get_snowpark_column_name_from_spark_column_name(
-        quoted_attr_name,
-        allow_non_exists=True,
-        is_qualified=has_plan_id,
-        source_qualifiers=source_qualifiers if has_plan_id else None,
+        quoted_attr_name, allow_non_exists=True
     )
-
     if snowpark_name is not None:
-        is_hidden = column_mapping.is_hidden_column_reference(
-            quoted_attr_name, source_qualifiers
-        )
-        col = get_col(snowpark_name, is_hidden)
+        col = get_col(snowpark_name)
         qualifiers = column_mapping.get_qualifier_for_spark_column(quoted_attr_name)
     else:
         # this means it has to be a struct column with a field name
@@ -2619,9 +2619,18 @@ def map_unresolved_function(
                     result_type = input_type.element_type
                     result_exp = fn(snowpark_args[0])
                 case _:
-                    spark_col_names = ["key", "value"]
-                    result_exp = fn(snowpark_args[0])
-                    result_type = [input_type.key_type, input_type.value_type]
+                    # Check if the type has map-like attributes before accessing them
+                    if hasattr(input_type, "key_type") and hasattr(
+                        input_type, "value_type"
+                    ):
+                        spark_col_names = ["key", "value"]
+                        result_exp = fn(snowpark_args[0])
+                        result_type = [input_type.key_type, input_type.value_type]
+                    else:
+                        # Throw proper error for types without key_type/value_type attributes
+                        raise AnalysisException(
+                            f'[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "{function_name}({snowpark_arg_names[0]})" due to data type mismatch: Parameter 1 requires the ("ARRAY" or "MAP") type, however "{snowpark_arg_names[0]}" has the type "{str(input_type)}".'
+                        )
         case "expm1":
             spark_function_name = f"EXPM1({snowpark_arg_names[0]})"
             result_exp = snowpark_fn.exp(*snowpark_args) - 1
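The new else branch replaces an unguarded read of input_type.key_type, which previously surfaced as a bare Python AttributeError whenever the argument was neither an array nor a map. Judging by the key/value column names and the ARRAY-or-MAP wording, this fallback most likely backs explode-style collection functions; under that assumption (the session bootstrap and column names below are illustrative, not taken from the package), the user-visible difference is roughly:

    # Hedged sketch: assumes this branch serves explode()-like functions over a
    # snowpark-connect backed Spark Connect session.
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()  # assumed to point at snowpark-connect 0.28.1
    df = spark.createDataFrame([(1, "a")], ["id", "label"])

    # 0.27.0: the mapper read input_type.key_type on a string column and failed with AttributeError.
    # 0.28.1: the same call fails cleanly with a Spark-style error, e.g.
    #   [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "explode(label)" ...
    df.select(F.explode(df.label)).show()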
@@ -34,7 +34,6 @@ def map_unresolved_star(
     column_mapping: ColumnNameMap,
     typer: ExpressionTyper,
 ) -> tuple[list[str], TypedColumn]:
-
     if exp.unresolved_star.HasField("unparsed_target"):
         unparsed_target = exp.unresolved_star.unparsed_target
         name_parts = split_fully_qualified_spark_name(unparsed_target)
@@ -103,7 +102,7 @@ def map_unresolved_star(
         prefix_candidate_str = f"{prefix_candidate_str}.{name_parts[i]}"
         prefix_candidate = (
             column_mapping.get_snowpark_column_name_from_spark_column_name(
-                prefix_candidate_str, allow_non_exists=True, is_qualified=(i > 0)
+                prefix_candidate_str, allow_non_exists=True
             )
         )
         if prefix_candidate is None:
@@ -181,7 +180,7 @@ def map_unresolved_star_struct(
         prefix_candidate_str = f"{prefix_candidate_str}.{name_parts[i]}"
         prefix_candidate = (
             column_mapping.get_snowpark_column_name_from_spark_column_name(
-                prefix_candidate_str, allow_non_exists=True, is_qualified=(i > 0)
+                prefix_candidate_str, allow_non_exists=True
            )
        )
        if prefix_candidate is None:
@@ -13,7 +13,6 @@ from snowflake.core.exceptions import APIError, NotFoundError
 from snowflake.core.schema import Schema
 from snowflake.core.table import Table, TableColumn
 
-from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
     unquote_if_quoted,
@@ -34,12 +33,19 @@ from snowflake.snowpark_connect.relation.catalogs.abstract_spark_catalog import
 )
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.utils.identifiers import (
+    FQN,
+    spark_to_sf_single_id_with_unquoting,
     split_fully_qualified_spark_name,
 )
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.telemetry import (
     SnowparkConnectNotImplementedError,
 )
+from snowflake.snowpark_connect.utils.temporary_view_cache import (
+    get_temp_view,
+    get_temp_view_normalized_names,
+    unregister_temp_view,
+)
 from snowflake.snowpark_connect.utils.udf_cache import cached_udf
 
 
@@ -203,6 +209,93 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})
 
+    def _get_temp_view_prefixes(self, spark_dbName: str | None) -> list[str]:
+        if spark_dbName is None:
+            return []
+        return [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_dbName)
+        ]
+
+    def _list_temp_views(
+        self,
+        spark_dbName: str | None = None,
+        pattern: str | None = None,
+    ) -> typing.Tuple[
+        list[str | None],
+        list[list[str | None]],
+        list[str],
+        list[str | None],
+        list[str | None],
+        list[bool],
+    ]:
+        catalogs: list[str | None] = list()
+        namespaces: list[list[str | None]] = list()
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        table_types: list[str | None] = list()
+        is_temporaries: list[bool] = list()
+
+        temp_views_prefix = ".".join(self._get_temp_view_prefixes(spark_dbName))
+        normalized_spark_dbName = (
+            temp_views_prefix.lower()
+            if global_config.spark_sql_caseSensitive
+            else temp_views_prefix
+        )
+        normalized_global_temp_database_name = (
+            quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase.lower()
+            )
+            if global_config.spark_sql_caseSensitive
+            else quote_name_without_upper_casing(
+                global_config.spark_sql_globalTempDatabase
+            )
+        )
+
+        temp_views = get_temp_view_normalized_names()
+        null_safe_pattern = pattern if pattern is not None else ""
+
+        for temp_view in temp_views:
+            normalized_temp_view = (
+                temp_view.lower()
+                if global_config.spark_sql_caseSensitive
+                else temp_view
+            )
+            fqn = FQN.from_string(temp_view)
+            normalized_schema = (
+                fqn.schema.lower()
+                if fqn.schema is not None and global_config.spark_sql_caseSensitive
+                else fqn.schema
+            )
+
+            is_global_view = normalized_global_temp_database_name == normalized_schema
+            is_local_temp_view = fqn.schema is None
+            # Temporary views are always shown if they match the pattern
+            matches_prefix = (
+                normalized_spark_dbName == normalized_schema or is_local_temp_view
+            )
+            if matches_prefix and bool(
+                re.match(null_safe_pattern, normalized_temp_view)
+            ):
+                names.append(unquote_if_quoted(fqn.name))
+                catalogs.append(None)
+                namespaces.append(
+                    [global_config.spark_sql_globalTempDatabase]
+                    if is_global_view
+                    else []
+                )
+                descriptions.append(None)
+                table_types.append("TEMPORARY")
+                is_temporaries.append(True)
+        return (
+            catalogs,
+            namespaces,
+            names,
+            descriptions,
+            table_types,
+            is_temporaries,
+        )
+
     def listTables(
         self,
         spark_dbName: str | None = None,
@@ -232,8 +325,7 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             schema=sf_quote(sf_schema),
             pattern=_normalize_identifier(pattern),
         )
-        names: list[str] = list()
-        catalogs: list[str] = list()
+        catalogs: list[str | None] = list()
         namespaces: list[list[str | None]] = list()
         names: list[str] = list()
         descriptions: list[str | None] = list()
@@ -253,6 +345,22 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             descriptions.append(o[6] if o[6] else None)
             table_types.append("PERMANENT")
             is_temporaries.append(False)
+
+        (
+            non_materialized_catalogs,
+            non_materialized_namespaces,
+            non_materialized_names,
+            non_materialized_descriptions,
+            non_materialized_table_types,
+            non_materialized_is_temporaries,
+        ) = self._list_temp_views(spark_dbName, pattern)
+        catalogs.extend(non_materialized_catalogs)
+        namespaces.extend(non_materialized_namespaces)
+        names.extend(non_materialized_names)
+        descriptions.extend(non_materialized_descriptions)
+        table_types.extend(non_materialized_table_types)
+        is_temporaries.extend(non_materialized_is_temporaries)
+
         return pandas.DataFrame(
             {
                 "name": names,
@@ -297,6 +405,36 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_tableName: str,
     ) -> pandas.DataFrame:
         """Listing a single table/view with provided name that's accessible in Snowflake."""
+
+        def _get_temp_view():
+            spark_table_name_parts = [
+                quote_name_without_upper_casing(part)
+                for part in split_fully_qualified_spark_name(spark_tableName)
+            ]
+            spark_view_name = ".".join(spark_table_name_parts)
+            temp_view = get_temp_view(spark_view_name)
+            if temp_view:
+                return pandas.DataFrame(
+                    {
+                        "name": [unquote_if_quoted(spark_table_name_parts[-1])],
+                        "catalog": [None],
+                        "namespace": [
+                            [unquote_if_quoted(spark_table_name_parts[-2])]
+                            if len(spark_table_name_parts) > 1
+                            else []
+                        ],
+                        "description": [None],
+                        "tableType": ["TEMPORARY"],
+                        "isTemporary": [True],
+                    }
+                )
+            return None
+
+        # Attempt to get the view from the non materialized views first
+        temp_view = _get_temp_view()
+        if temp_view is not None:
+            return temp_view
+
         sp_catalog = get_or_create_snowpark_session().catalog
         catalog, sf_database, sf_schema, table_name = _process_multi_layer_identifier(
             spark_tableName
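Single-table lookups take the same shortcut: the name is normalized part by part and checked against the temporary-view cache before any Snowflake metadata query runs. A hedged sketch of the corresponding client call (the getTable() usage and the view name are illustrative assumptions):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed snowpark-connect 0.28.1 session
    spark.range(1).createOrReplaceTempView("quick_view")

    t = spark.catalog.getTable("quick_view")
    # Served from the in-memory cache: tableType is "TEMPORARY", isTemporary is True,
    # and catalog is None because no Snowflake object backs the view.
    print(t.name, t.tableType, t.isTemporary)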
@@ -360,12 +498,64 @@ class SnowflakeCatalog(AbstractSparkCatalog):
             exists = False
         return pandas.DataFrame({"exists": [exists]})
 
+    def _list_temp_view_columns(
+        self,
+        spark_tableName: str,
+        spark_dbName: typing.Optional[str] = None,
+    ):
+        spark_view_name_parts = [
+            quote_name_without_upper_casing(part)
+            for part in split_fully_qualified_spark_name(spark_tableName)
+        ]
+        spark_view_name_parts = (
+            self._get_temp_view_prefixes(spark_dbName) + spark_view_name_parts
+        )
+        spark_view_name = ".".join(spark_view_name_parts)
+        temp_view = get_temp_view(spark_view_name)
+
+        if not temp_view:
+            return None
+
+        names: list[str] = list()
+        descriptions: list[str | None] = list()
+        data_types: list[str] = list()
+        nullables: list[bool] = list()
+        is_partitions: list[bool] = list()
+        is_buckets: list[bool] = list()
+
+        for field, spark_column in zip(
+            temp_view.dataframe.schema.fields,
+            temp_view.column_map.get_spark_columns(),
+        ):
+            names.append(spark_column)
+            descriptions.append(None)
+            data_types.append(field.datatype.simpleString())
+            nullables.append(field.nullable)
+            is_partitions.append(False)
+            is_buckets.append(False)
+
+        return pandas.DataFrame(
+            {
+                "name": names,
+                "description": descriptions,
+                "dataType": data_types,
+                "nullable": nullables,
+                "isPartition": is_partitions,
+                "isBucket": is_buckets,
+            }
+        )
+
     def listColumns(
         self,
         spark_tableName: str,
         spark_dbName: typing.Optional[str] = None,
     ) -> pandas.DataFrame:
         """List all columns in a table/view, optionally database name filter can be provided."""
+
+        temp_view_columns = self._list_temp_view_columns(spark_tableName, spark_dbName)
+        if temp_view_columns is not None:
+            return temp_view_columns
+
         sp_catalog = get_or_create_snowpark_session().catalog
         columns: list[TableColumn] | None = None
         if spark_dbName is None:
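listColumns() follows the same pattern: if the name (optionally prefixed with spark_dbName) resolves to a cached temporary view, the answer is built from the cached DataFrame schema rather than from Snowflake column metadata. A small sketch of what that looks like from the client (session setup and names are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed snowpark-connect 0.28.1 session
    spark.createDataFrame([(1, "a")], ["id", "label"]).createOrReplaceTempView("typed_view")

    # Column types come from field.datatype.simpleString() on the cached schema,
    # e.g. "bigint" and "string"; isPartition and isBucket are always False for temp views.
    for col in spark.catalog.listColumns("typed_view"):
        print(col.name, col.dataType, col.nullable)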
@@ -455,17 +645,15 @@ class SnowflakeCatalog(AbstractSparkCatalog):
         spark_view_name: str,
     ) -> DataFrameContainer:
         session = get_or_create_snowpark_session()
-        schema = global_config.spark_sql_globalTempDatabase
-        result_df = session.sql(
-            "drop view if exists identifier(?)",
-            params=[f"{sf_quote(schema)}.{sf_quote(spark_view_name)}"],
-        )
-        result_df = result_df.select(
-            functions.contains('"status"', functions.lit("successfully dropped")).alias(
-                "value"
+        if not spark_view_name == "":
+            schema = global_config.spark_sql_globalTempDatabase
+            result = unregister_temp_view(
+                f"{spark_to_sf_single_id_with_unquoting(schema)}.{spark_to_sf_single_id_with_unquoting(spark_view_name)}"
             )
-        )
+        else:
+            result = False
         columns = ["value"]
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,
@@ -479,15 +667,14 @@ class SnowflakeCatalog(AbstractSparkCatalog):
     ) -> DataFrameContainer:
         """Drop the current temporary view."""
         session = get_or_create_snowpark_session()
-        result = session.sql(
-            "drop view if exists identifier(?)",
-            params=[sf_quote(spark_view_name)],
-        ).collect()
-        view_was_dropped = (
-            len(result) == 1 and "successfully dropped" in result[0]["status"]
-        )
-        result_df = session.createDataFrame([(view_was_dropped,)], schema=["value"])
         columns = ["value"]
+        if spark_view_name:
+            result = unregister_temp_view(
+                spark_to_sf_single_id_with_unquoting(spark_view_name)
+            )
+        else:
+            result = False
+        result_df = session.createDataFrame([result], schema=columns)
         return DataFrameContainer.create_with_column_mapping(
             dataframe=result_df,
             spark_column_names=columns,
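Both drop paths now unregister the view from the in-memory cache instead of executing DROP VIEW IF EXISTS and scraping the status message; the returned boolean reflects whether anything was actually removed, and empty view names are guarded to return False. A hedged sketch of the client-side effect (session setup and the view name are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed snowpark-connect 0.28.1 session
    spark.range(1).createOrReplaceTempView("scratch")

    print(spark.catalog.dropTempView("scratch"))  # True: the view was registered in the cache
    print(spark.catalog.dropTempView("scratch"))  # False: nothing left to unregister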
@@ -23,7 +23,6 @@ from snowflake.snowpark_connect.relation.map_relation import map_relation
 from snowflake.snowpark_connect.typed_column import TypedColumn
 from snowflake.snowpark_connect.utils.context import (
     get_sql_aggregate_function_count,
-    not_resolving_fun_args,
     push_outer_dataframe,
     set_current_grouping_columns,
 )
@@ -336,15 +335,14 @@ def map_aggregate(
     typer = ExpressionTyper(input_df)
 
     def _map_column(exp: expression_proto.Expression) -> tuple[str, TypedColumn]:
-        with not_resolving_fun_args():
-            new_names, snowpark_column = map_expression(
-                exp, input_container.column_map, typer
+        new_names, snowpark_column = map_expression(
+            exp, input_container.column_map, typer
+        )
+        if len(new_names) != 1:
+            raise SnowparkConnectNotImplementedError(
+                "Multi-column aggregate expressions are not supported"
             )
-            if len(new_names) != 1:
-                raise SnowparkConnectNotImplementedError(
-                    "Multi-column aggregate expressions are not supported"
-                )
-            return new_names[0], snowpark_column
+        return new_names[0], snowpark_column
 
     raw_groupings: list[tuple[str, TypedColumn]] = []
     raw_aggregations: list[tuple[str, TypedColumn]] = []
@@ -431,12 +429,18 @@ def map_aggregate(
             if groupings:
                 # Normal GROUP BY with explicit grouping columns
                 result = input_df.group_by(groupings)
-            else:
+            elif not is_group_by_all:
                 # No explicit GROUP BY - this is an aggregate over the entire table
                 # Use a dummy constant that will be excluded from the final result
                 result = input_df.with_column(
                     "__dummy_group__", snowpark_fn.lit(1)
                 ).group_by("__dummy_group__")
+            else:
+                # GROUP BY ALL with only one aggregate column
+                # Snowpark doesn't support GROUP BY ALL
+                # TODO: Change in future with Snowpark Supported arguments or API for GROUP BY ALL
+                result = input_df.group_by()
+
         case snowflake_proto.Aggregate.GROUP_TYPE_ROLLUP:
             result = input_df.rollup(groupings)
         case snowflake_proto.Aggregate.GROUP_TYPE_CUBE:
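The new elif/else split separates a plain ungrouped aggregate (still routed through the __dummy_group__ constant) from GROUP BY ALL, which Snowpark does not support directly: when GROUP BY ALL contributes no grouping columns, the plan now falls back to Snowpark's group_by() with no arguments, i.e. a whole-table aggregate. A hedged sketch of a query that would take the new branch (session setup, table, and column names are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed snowpark-connect 0.28.1 session
    spark.createDataFrame([(1,), (2,)], ["x"]).createOrReplaceTempView("t")

    # Only aggregate expressions appear in the select list, so GROUP BY ALL yields no
    # grouping columns; 0.28.1 maps this to input_df.group_by() with no arguments.
    spark.sql("SELECT sum(x) AS total FROM t GROUP BY ALL").show()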