snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. snowflake/snowpark_connect/client/server.py +37 -0
  2. snowflake/snowpark_connect/config.py +72 -3
  3. snowflake/snowpark_connect/expression/error_utils.py +28 -0
  4. snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
  5. snowflake/snowpark_connect/expression/map_cast.py +108 -17
  6. snowflake/snowpark_connect/expression/map_udf.py +1 -0
  7. snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
  8. snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
  10. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
  11. snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
  12. snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
  13. snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
  14. snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
  15. snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
  16. snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
  17. snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
  18. snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
  19. snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
  20. snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
  21. snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
  22. snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
  23. snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
  24. snowflake/snowpark_connect/resources_initializer.py +90 -29
  25. snowflake/snowpark_connect/server.py +6 -41
  26. snowflake/snowpark_connect/server_common/__init__.py +4 -1
  27. snowflake/snowpark_connect/type_support.py +130 -0
  28. snowflake/snowpark_connect/utils/context.py +8 -0
  29. snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
  30. snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
  31. snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
  32. snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
  33. snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
  34. snowflake/snowpark_connect/utils/telemetry.py +33 -22
  35. snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
  36. snowflake/snowpark_connect/version.py +1 -1
  37. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
  38. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
  39. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
  40. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
  41. {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
  42. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
  43. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
  44. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  45. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
  46. {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/map_cast.py
@@ -6,6 +6,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
 from pyspark.errors.exceptions.base import (
     AnalysisException,
+    ArithmeticException,
     IllegalArgumentException,
     NumberFormatException,
     SparkRuntimeException,
@@ -18,7 +19,9 @@ from snowflake.snowpark.types import (
     BooleanType,
     DataType,
     DateType,
+    DecimalType,
     DoubleType,
+    FloatType,
     IntegerType,
     LongType,
     MapType,
@@ -36,6 +39,13 @@ from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.expression.error_utils import raise_error_helper
+from snowflake.snowpark_connect.expression.integral_types_support import (
+    apply_fractional_to_integral_cast,
+    apply_fractional_to_integral_cast_with_ansi_check,
+    apply_integral_overflow_with_ansi_check,
+    get_integral_type_bounds,
+)
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import (
     map_type_string_to_snowpark_type,
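
These four helpers come from the new integral_types_support module (file 4 in the list above, +219 lines), whose body is not included in this diff view. The one contract the hunks below lean on is get_integral_type_bounds; the sketch below is a hypothetical reconstruction of its shape, assuming standard two's-complement ranges for Snowpark's integral types, not the shipped implementation:

    # Hypothetical sketch only -- the shipped integral_types_support.py is not
    # shown in this diff. Bounds follow the two's-complement widths that
    # Spark's TINYINT/SMALLINT/INT/BIGINT use.
    from snowflake.snowpark.types import ByteType, IntegerType, LongType, ShortType

    _INTEGRAL_BOUNDS = {  # assumed name, for illustration
        ByteType: (-(2**7), 2**7 - 1),       # TINYINT: -128 .. 127
        ShortType: (-(2**15), 2**15 - 1),    # SMALLINT: -32768 .. 32767
        IntegerType: (-(2**31), 2**31 - 1),  # INT
        LongType: (-(2**63), 2**63 - 1),     # BIGINT
    }

    def get_integral_type_bounds(to_type):
        """Return the (min, max) value range for an integral target type."""
        return _INTEGRAL_BOUNDS[type(to_type)]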
@@ -54,7 +64,7 @@ SYMBOL_FUNCTIONS = {"<", ">", "<=", ">=", "!=", "+", "-", "*", "/", "%", "div"}
 CAST_FUNCTIONS = {
     "boolean": types_proto.DataType(boolean=types_proto.DataType.Boolean()),
     "int": types_proto.DataType(integer=types_proto.DataType.Integer()),
-    "smallint": types_proto.DataType(integer=types_proto.DataType.Integer()),
+    "smallint": types_proto.DataType(short=types_proto.DataType.Short()),
     "bigint": types_proto.DataType(long=types_proto.DataType.Long()),
     "tinyint": types_proto.DataType(byte=types_proto.DataType.Byte()),
     "float": types_proto.DataType(float=types_proto.DataType.Float()),
@@ -237,6 +247,11 @@ def map_cast(
         case (_, BooleanType()) if isinstance(from_type, _NumericType):
             result_exp = col.cast(LongType()).cast(to_type)
 
+        case (_IntegralType(), _IntegralType()):
+            result_exp = apply_integral_overflow_with_ansi_check(
+                col, to_type, spark_sql_ansi_enabled
+            )
+
         # binary
         case (StringType(), BinaryType()):
             result_exp = snowpark_fn.to_binary(col, "UTF-8")
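
Two related fixes here. First, the CAST_FUNCTIONS entry for smallint previously resolved to the proto Integer type, so CAST(x AS SMALLINT) surfaced as a 4-byte INT; it now resolves to Short, matching Spark. Second, the new integral-to-integral case routes narrowing casts through apply_integral_overflow_with_ansi_check, which, judging by its name and call sites, reproduces Spark's two modes: raise on overflow when spark.sql.ansi.enabled is true, otherwise match Spark's non-ANSI result. For reference, vanilla Spark behaves as follows (illustrative snippet, assuming a plain PySpark session named spark):

    # Vanilla Spark reference behavior for a narrowing integral cast.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST(128 AS TINYINT)").collect()[0][0]
    # -> -128: the value wraps around (two's complement)

    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT CAST(128 AS TINYINT)").collect()
    # -> ArithmeticException: [CAST_OVERFLOW] The value 128 of the type "INT"
    #    cannot be cast to "TINYINT" due to an overflow. ...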
@@ -272,6 +287,44 @@ def map_cast(
             result_exp = snowpark_fn.to_varchar(col, "UTF-8")
 
         # numeric
+        case (_, _) if isinstance(from_type, (FloatType, DoubleType)) and isinstance(
+            to_type, _IntegralType
+        ):
+            truncated = (
+                snowpark_fn.when(
+                    col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                )
+                .when(col < 0, snowpark_fn.ceil(col))
+                .otherwise(snowpark_fn.floor(col))
+            )
+
+            if spark_sql_ansi_enabled:
+                result_exp = apply_fractional_to_integral_cast_with_ansi_check(
+                    truncated, to_type, True
+                )
+            else:
+                target_min, target_max = get_integral_type_bounds(to_type)
+                result_exp = (
+                    snowpark_fn.when(
+                        truncated > snowpark_fn.lit(target_max),
+                        snowpark_fn.lit(target_max),
+                    )
+                    .when(
+                        truncated < snowpark_fn.lit(target_min),
+                        snowpark_fn.lit(target_min),
+                    )
+                    .otherwise(truncated.cast(to_type))
+                )
+        case (_, _) if isinstance(from_type, DecimalType) and isinstance(
+            to_type, _IntegralType
+        ):
+            result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
+                snowpark_fn.floor(col)
+            )
+            result_exp = result_exp.cast(to_type)
+            result_exp = apply_integral_overflow_with_ansi_check(
+                result_exp, to_type, spark_sql_ansi_enabled
+            )
         case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
             to_type, _IntegralType
         ):
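
Both new branches implement Spark's fractional-to-integral semantics: NaN maps to 0, values truncate toward zero (ceil for negatives, floor otherwise), and out-of-range values either clamp to the target bounds (non-ANSI) or raise through the _with_ansi_check helper (ANSI). Note that the col == snowpark_fn.lit(float("nan")) test is sound on Snowflake, which, unlike IEEE comparison semantics elsewhere, treats NaN values as equal to each other. Vanilla Spark for comparison (illustrative, assuming a PySpark session named spark with ANSI disabled):

    # Vanilla Spark reference (ANSI disabled): NaN -> 0, truncate toward zero,
    # clamp to the target type's range on overflow.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST(CAST('NaN' AS DOUBLE) AS INT)").collect()[0][0]  # 0
    spark.sql("SELECT CAST(-1.9D AS INT)").collect()[0][0]                  # -1
    spark.sql("SELECT CAST(3.0E9 AS INT)").collect()[0][0]                  # 2147483647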
@@ -282,16 +335,49 @@ def map_cast(
                 .when(col < 0, snowpark_fn.ceil(col))
                 .otherwise(snowpark_fn.floor(col))
             )
-            result_exp = result_exp.cast(to_type)
+            result_exp = apply_fractional_to_integral_cast(result_exp, to_type)
         case (StringType(), _) if (isinstance(to_type, _IntegralType)):
             if spark_sql_ansi_enabled:
-                result_exp = snowpark_fn.cast(col, DoubleType())
+                double_val = snowpark_fn.cast(col, DoubleType())
+
+                target_min, target_max = get_integral_type_bounds(to_type)
+                raise_error = raise_error_helper(to_type, NumberFormatException)
+                to_type_name = to_type.__class__.__name__.upper().replace("TYPE", "")
+
+                truncated = snowpark_fn.when(
+                    double_val < 0, snowpark_fn.ceil(double_val)
+                ).otherwise(snowpark_fn.floor(double_val))
+
+                result_exp = snowpark_fn.when(
+                    (truncated < snowpark_fn.lit(target_min))
+                    | (truncated > snowpark_fn.lit(target_max)),
+                    raise_error(
+                        snowpark_fn.lit("[CAST_INVALID_INPUT] The value '"),
+                        col,
+                        snowpark_fn.lit(
+                            f'\' of the type "STRING" cannot be cast to "{to_type_name}" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        ),
+                    ),
+                ).otherwise(truncated.cast(to_type))
             else:
-                result_exp = snowpark_fn.try_cast(col, DoubleType())
-                result_exp = snowpark_fn.when(
-                    result_exp < 0, snowpark_fn.ceil(result_exp)
-                ).otherwise(snowpark_fn.floor(result_exp))
-                result_exp = result_exp.cast(to_type)
+                double_val = snowpark_fn.try_cast(col, DoubleType())
+
+                truncated = snowpark_fn.when(
+                    double_val < 0, snowpark_fn.ceil(double_val)
+                ).otherwise(snowpark_fn.floor(double_val))
+
+                target_min, target_max = get_integral_type_bounds(to_type)
+                result_exp = (
+                    snowpark_fn.when(
+                        double_val.isNull(), snowpark_fn.lit(None).cast(to_type)
+                    )
+                    .when(
+                        (truncated < snowpark_fn.lit(target_min))
+                        | (truncated > snowpark_fn.lit(target_max)),
+                        snowpark_fn.lit(None).cast(to_type),
+                    )
+                    .otherwise(truncated.cast(to_type))
+                )
         # https://docs.snowflake.com/en/sql-reference/functions/try_cast Only works on certain types (mostly non-structured ones)
         case (StringType(), _) if isinstance(to_type, _NumericType) or isinstance(
             to_type, StringType
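
For string sources, the ANSI path now parses through DOUBLE, truncates toward zero, and raises a [CAST_INVALID_INPUT] NumberFormatException via raise_error when the truncated value falls outside the target range, rather than stopping at the double cast as 1.6.0 did. The non-ANSI path keeps try_cast but now folds out-of-range values to NULL as well as malformed ones. Vanilla Spark reference (illustrative, assuming a PySpark session named spark, ANSI disabled):

    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('12.7' AS INT)").collect()[0][0]         # 12 (truncated)
    spark.sql("SELECT CAST('abc' AS INT)").collect()[0][0]          # None (malformed)
    spark.sql("SELECT CAST('99999999999' AS INT)").collect()[0][0]  # None (out of range)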
@@ -368,10 +454,19 @@ def sanity_check(
     except Exception:
         raise_cast_failure_exception = True
     if raise_cast_failure_exception:
-        exception = NumberFormatException(
-            """[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error."""
-        )
-        attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+        if not isinstance(from_type, StringType) and isinstance(to_type, _IntegralType):
+            from_type_name = from_type.__class__.__name__.upper().replace("TYPE", "")
+            to_type_name = to_type.__class__.__name__.upper().replace("TYPE", "")
+            value_suffix = "L" if isinstance(from_type, LongType) else ""
+            exception = ArithmeticException(
+                f"""[CAST_OVERFLOW] The value {value}{value_suffix} of the type "{from_type_name}" cannot be cast to "{to_type_name}" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error."""
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+        else:
+            exception = NumberFormatException(
+                """[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error."""
+            )
+            attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
     raise exception
 
 
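sanity_check now distinguishes numeric overflow from malformed input: a non-string source cast to an integral target raises Spark's [CAST_OVERFLOW] ArithmeticException, including the L suffix Spark prints for BIGINT values, while everything else keeps the previous [CAST_INVALID_INPUT] NumberFormatException. The new message mirrors what vanilla Spark emits in ANSI mode (illustrative, assuming a PySpark session named spark):

    # What vanilla Spark raises for the same overflow in ANSI mode, for comparison.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT CAST(9223372036854775807L AS INT)").collect()
    # org.apache.spark.SparkArithmeticException: [CAST_OVERFLOW] The value
    # 9223372036854775807L of the type "BIGINT" cannot be cast to "INT" due to
    # an overflow. Use `try_cast` to tolerate overflow and return NULL instead.
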
@@ -386,15 +481,11 @@ def _cast_string_to_year_month_interval(col: Column, to_type: YearMonthIntervalT
     5. 'INTERVAL [+|-]'[+|-]y' YEAR' format - extract the y part
     6. 'INTERVAL [+|-]'[+|-]m' MONTH' format - extract the m part
     """
-    from snowflake.snowpark_connect.expression.map_unresolved_function import (
-        _raise_error_helper,
-    )
-
     # Extract values from different formats
     value = snowpark_fn.regexp_extract(col, "'([^']+)'", 1)
     years = snowpark_fn.regexp_extract(col, "^[+-]?\\d+", 0)
     months = snowpark_fn.regexp_extract(col, "-(\\d+)$", 1)
-    raise_error = _raise_error_helper(to_type, IllegalArgumentException)
+    raise_error = raise_error_helper(to_type, IllegalArgumentException)
 
     # For MONTH-only intervals, treat the input as months
     if (
snowflake/snowpark_connect/expression/map_udf.py
@@ -244,6 +244,7 @@ def map_common_inline_user_defined_udf(
         # All Scala UDFs return Variant, so we always need to cast back to the original type
         result_expr = snowpark_fn.cast(udf_call_expr, original_return_type)
         result_type = original_return_type
+
     elif isinstance(original_return_type, (MapType, StructType)) and isinstance(
         processed_return_type, VariantType
     ):