snowpark-connect 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- snowflake/snowpark_connect/client/server.py +37 -0
- snowflake/snowpark_connect/config.py +72 -3
- snowflake/snowpark_connect/expression/error_utils.py +28 -0
- snowflake/snowpark_connect/expression/integral_types_support.py +219 -0
- snowflake/snowpark_connect/expression/map_cast.py +108 -17
- snowflake/snowpark_connect/expression/map_udf.py +1 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +229 -96
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.13-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.13-0.2.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.13.16.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.13-3.5.6.jar +0 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +43 -1
- snowflake/snowpark_connect/relation/read/map_read_csv.py +73 -4
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_json.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_socket.py +4 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +4 -1
- snowflake/snowpark_connect/relation/read/map_read_text.py +4 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +6 -0
- snowflake/snowpark_connect/resources_initializer.py +90 -29
- snowflake/snowpark_connect/server.py +6 -41
- snowflake/snowpark_connect/server_common/__init__.py +4 -1
- snowflake/snowpark_connect/type_support.py +130 -0
- snowflake/snowpark_connect/utils/context.py +8 -0
- snowflake/snowpark_connect/utils/java_stored_procedure.py +53 -27
- snowflake/snowpark_connect/utils/java_udaf_utils.py +46 -28
- snowflake/snowpark_connect/utils/java_udtf_utils.py +1 -1
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +48 -15
- snowflake/snowpark_connect/utils/scala_udf_utils.py +98 -22
- snowflake/snowpark_connect/utils/telemetry.py +33 -22
- snowflake/snowpark_connect/utils/udxf_import_utils.py +9 -2
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-submit +12 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/METADATA +4 -2
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/RECORD +46 -37
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-1.6.0.data → snowpark_connect-1.7.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-1.6.0.dist-info → snowpark_connect-1.7.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/expression/map_cast.py

@@ -6,6 +6,7 @@ import pyspark.sql.connect.proto.expressions_pb2 as expressions_proto
 import pyspark.sql.connect.proto.types_pb2 as types_proto
 from pyspark.errors.exceptions.base import (
     AnalysisException,
+    ArithmeticException,
     IllegalArgumentException,
     NumberFormatException,
     SparkRuntimeException,
@@ -18,7 +19,9 @@ from snowflake.snowpark.types import (
     BooleanType,
     DataType,
     DateType,
+    DecimalType,
     DoubleType,
+    FloatType,
     IntegerType,
     LongType,
     MapType,
@@ -36,6 +39,13 @@ from snowflake.snowpark_connect.column_name_handler import ColumnNameMap
 from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.error.error_codes import ErrorCodes
 from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
+from snowflake.snowpark_connect.expression.error_utils import raise_error_helper
+from snowflake.snowpark_connect.expression.integral_types_support import (
+    apply_fractional_to_integral_cast,
+    apply_fractional_to_integral_cast_with_ansi_check,
+    apply_integral_overflow_with_ansi_check,
+    get_integral_type_bounds,
+)
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import (
     map_type_string_to_snowpark_type,
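The new helper module integral_types_support.py (+219 lines) is not shown inline in this diff. Judging from the call sites below, get_integral_type_bounds returns the inclusive min/max for a target integral type; a minimal, hypothetical sketch for orientation only:

    from snowflake.snowpark.types import ByteType, IntegerType, LongType, ShortType

    # Illustrative reimplementation; the real helper lives in the new
    # integral_types_support.py, whose body this diff omits.
    _INTEGRAL_BOUNDS = {
        ByteType: (-(2**7), 2**7 - 1),     # TINYINT: -128 .. 127
        ShortType: (-(2**15), 2**15 - 1),  # SMALLINT: -32768 .. 32767
        IntegerType: (-(2**31), 2**31 - 1),
        LongType: (-(2**63), 2**63 - 1),
    }

    def get_integral_type_bounds(to_type):
        # Inclusive (min, max) for the target integral type.
        return _INTEGRAL_BOUNDS[type(to_type)]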
@@ -54,7 +64,7 @@ SYMBOL_FUNCTIONS = {"<", ">", "<=", ">=", "!=", "+", "-", "*", "/", "%", "div"}
 CAST_FUNCTIONS = {
     "boolean": types_proto.DataType(boolean=types_proto.DataType.Boolean()),
     "int": types_proto.DataType(integer=types_proto.DataType.Integer()),
-    "smallint": types_proto.DataType(
+    "smallint": types_proto.DataType(short=types_proto.DataType.Short()),
     "bigint": types_proto.DataType(long=types_proto.DataType.Long()),
     "tinyint": types_proto.DataType(byte=types_proto.DataType.Byte()),
     "float": types_proto.DataType(float=types_proto.DataType.Float()),
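A quick sanity check of what the corrected "smallint" entry produces (illustration, not code from the package):

    import pyspark.sql.connect.proto.types_pb2 as types_proto

    # The corrected entry populates the proto's `short` variant, so a
    # SQL-style cast to SMALLINT now resolves to the 16-bit Short type.
    t = types_proto.DataType(short=types_proto.DataType.Short())
    print(t.WhichOneof("kind"))  # -> "short"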
@@ -237,6 +247,11 @@ def map_cast(
         case (_, BooleanType()) if isinstance(from_type, _NumericType):
             result_exp = col.cast(LongType()).cast(to_type)
 
+        case (_IntegralType(), _IntegralType()):
+            result_exp = apply_integral_overflow_with_ansi_check(
+                col, to_type, spark_sql_ansi_enabled
+            )
+
         # binary
         case (StringType(), BinaryType()):
             result_exp = snowpark_fn.to_binary(col, "UTF-8")
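The new integral-to-integral case targets Spark's cast semantics: an error under ANSI mode, Java-style two's-complement wraparound otherwise. A pure-Python model of those semantics (the actual apply_integral_overflow_with_ansi_check lives in integral_types_support.py, which this diff does not show, and raises Spark's ArithmeticException rather than Python's ArithmeticError):

    def spark_integral_cast(value: int, bits: int, ansi: bool) -> int:
        # Model of Spark's integral -> integral cast for a two's-complement
        # target of the given bit width.
        lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        if lo <= value <= hi:
            return value
        if ansi:
            # ANSI mode: overflow is a [CAST_OVERFLOW] error.
            raise ArithmeticError("CAST_OVERFLOW")
        # Non-ANSI mode: Java-style wraparound, e.g. CAST(300 AS TINYINT) -> 44.
        span = 2 ** bits
        return ((value - lo) % span) + lo

    assert spark_integral_cast(300, 8, ansi=False) == 44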
@@ -272,6 +287,44 @@ def map_cast(
             result_exp = snowpark_fn.to_varchar(col, "UTF-8")
 
         # numeric
+        case (_, _) if isinstance(from_type, (FloatType, DoubleType)) and isinstance(
+            to_type, _IntegralType
+        ):
+            truncated = (
+                snowpark_fn.when(
+                    col == snowpark_fn.lit(float("nan")), snowpark_fn.lit(0)
+                )
+                .when(col < 0, snowpark_fn.ceil(col))
+                .otherwise(snowpark_fn.floor(col))
+            )
+
+            if spark_sql_ansi_enabled:
+                result_exp = apply_fractional_to_integral_cast_with_ansi_check(
+                    truncated, to_type, True
+                )
+            else:
+                target_min, target_max = get_integral_type_bounds(to_type)
+                result_exp = (
+                    snowpark_fn.when(
+                        truncated > snowpark_fn.lit(target_max),
+                        snowpark_fn.lit(target_max),
+                    )
+                    .when(
+                        truncated < snowpark_fn.lit(target_min),
+                        snowpark_fn.lit(target_min),
+                    )
+                    .otherwise(truncated.cast(to_type))
+                )
+        case (_, _) if isinstance(from_type, DecimalType) and isinstance(
+            to_type, _IntegralType
+        ):
+            result_exp = snowpark_fn.when(col < 0, snowpark_fn.ceil(col)).otherwise(
+                snowpark_fn.floor(col)
+            )
+            result_exp = result_exp.cast(to_type)
+            result_exp = apply_integral_overflow_with_ansi_check(
+                result_exp, to_type, spark_sql_ansi_enabled
+            )
         case (_, _) if isinstance(from_type, _FractionalType) and isinstance(
             to_type, _IntegralType
         ):
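The non-ANSI float/double branch above encodes three rules worth calling out: NaN maps to 0, values truncate toward zero (ceil for negatives, floor otherwise), and out-of-range results saturate at the target bounds rather than wrapping. The same logic in plain Python, for reference:

    import math

    def nonansi_fractional_to_integral(value: float, lo: int, hi: int) -> int:
        # Mirrors the when/otherwise chain above: NaN -> 0, truncate toward
        # zero, then clamp to the target type's [lo, hi] range.
        if math.isnan(value):
            return 0
        truncated = math.ceil(value) if value < 0 else math.floor(value)
        return max(lo, min(hi, truncated))

    assert nonansi_fractional_to_integral(float("nan"), -128, 127) == 0
    assert nonansi_fractional_to_integral(-3.7, -128, 127) == -3  # toward zero
    assert nonansi_fractional_to_integral(1e9, -128, 127) == 127  # saturates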
@@ -282,16 +335,49 @@ def map_cast(
                 .when(col < 0, snowpark_fn.ceil(col))
                 .otherwise(snowpark_fn.floor(col))
             )
-            result_exp = result_exp
+            result_exp = apply_fractional_to_integral_cast(result_exp, to_type)
         case (StringType(), _) if (isinstance(to_type, _IntegralType)):
             if spark_sql_ansi_enabled:
-
+                double_val = snowpark_fn.cast(col, DoubleType())
+
+                target_min, target_max = get_integral_type_bounds(to_type)
+                raise_error = raise_error_helper(to_type, NumberFormatException)
+                to_type_name = to_type.__class__.__name__.upper().replace("TYPE", "")
+
+                truncated = snowpark_fn.when(
+                    double_val < 0, snowpark_fn.ceil(double_val)
+                ).otherwise(snowpark_fn.floor(double_val))
+
+                result_exp = snowpark_fn.when(
+                    (truncated < snowpark_fn.lit(target_min))
+                    | (truncated > snowpark_fn.lit(target_max)),
+                    raise_error(
+                        snowpark_fn.lit("[CAST_INVALID_INPUT] The value '"),
+                        col,
+                        snowpark_fn.lit(
+                            f'\' of the type "STRING" cannot be cast to "{to_type_name}" because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.'
+                        ),
+                    ),
+                ).otherwise(truncated.cast(to_type))
             else:
-
-
-
-
-
+                double_val = snowpark_fn.try_cast(col, DoubleType())
+
+                truncated = snowpark_fn.when(
+                    double_val < 0, snowpark_fn.ceil(double_val)
+                ).otherwise(snowpark_fn.floor(double_val))
+
+                target_min, target_max = get_integral_type_bounds(to_type)
+                result_exp = (
+                    snowpark_fn.when(
+                        double_val.isNull(), snowpark_fn.lit(None).cast(to_type)
+                    )
+                    .when(
+                        (truncated < snowpark_fn.lit(target_min))
+                        | (truncated > snowpark_fn.lit(target_max)),
+                        snowpark_fn.lit(None).cast(to_type),
+                    )
+                    .otherwise(truncated.cast(to_type))
+                )
         # https://docs.snowflake.com/en/sql-reference/functions/try_cast Only works on certain types (mostly non-structured ones)
         case (StringType(), _) if isinstance(to_type, _NumericType) or isinstance(
             to_type, StringType
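The two branches implement the behavioral split PySpark users see when casting strings to integral types. Illustratively (session setup elided; the results shown are assumptions based on the expressions above, not output from package tests):

    # Non-ANSI: malformed or out-of-range strings become NULL.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST('99999999999' AS INT) AS v").collect()
    # -> [Row(v=None)]

    # ANSI: the same cast raises NumberFormatException carrying the
    # [CAST_INVALID_INPUT] message text embedded in the expression above.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT CAST('99999999999' AS INT) AS v").collect()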
@@ -368,10 +454,19 @@ def sanity_check(
         except Exception:
             raise_cast_failure_exception = True
         if raise_cast_failure_exception:
-
-
-
-
+            if not isinstance(from_type, StringType) and isinstance(to_type, _IntegralType):
+                from_type_name = from_type.__class__.__name__.upper().replace("TYPE", "")
+                to_type_name = to_type.__class__.__name__.upper().replace("TYPE", "")
+                value_suffix = "L" if isinstance(from_type, LongType) else ""
+                exception = ArithmeticException(
+                    f"""[CAST_OVERFLOW] The value {value}{value_suffix} of the type "{from_type_name}" cannot be cast to "{to_type_name}" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error."""
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
+            else:
+                exception = NumberFormatException(
+                    """[CAST_INVALID_INPUT] Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary setting "spark.sql.ansi.enabled" to "false" may bypass this error."""
+                )
+                attach_custom_error_code(exception, ErrorCodes.INVALID_CAST)
             raise exception
 
 
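A worked example of the name and suffix mechanics in the new overflow branch (the types here are chosen for illustration):

    from snowflake.snowpark.types import IntegerType, LongType

    from_type, to_type = LongType(), IntegerType()
    from_type_name = from_type.__class__.__name__.upper().replace("TYPE", "")  # "LONG"
    to_type_name = to_type.__class__.__name__.upper().replace("TYPE", "")      # "INTEGER"
    value_suffix = "L" if isinstance(from_type, LongType) else ""              # "L"
    # So a long value that overflows INT is reported as, e.g.:
    # [CAST_OVERFLOW] The value 9223372036854775807L of the type "LONG" ...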
@@ -386,15 +481,11 @@ def _cast_string_to_year_month_interval(col: Column, to_type: YearMonthIntervalT
     5. 'INTERVAL [+|-]'[+|-]y' YEAR' format - extract the y part
     6. 'INTERVAL [+|-]'[+|-]m' MONTH' format - extract the m part
     """
-    from snowflake.snowpark_connect.expression.map_unresolved_function import (
-        _raise_error_helper,
-    )
-
     # Extract values from different formats
     value = snowpark_fn.regexp_extract(col, "'([^']+)'", 1)
     years = snowpark_fn.regexp_extract(col, "^[+-]?\\d+", 0)
     months = snowpark_fn.regexp_extract(col, "-(\\d+)$", 1)
-    raise_error = _raise_error_helper(to_type, IllegalArgumentException)
+    raise_error = raise_error_helper(to_type, IllegalArgumentException)
 
     # For MONTH-only intervals, treat the input as months
     if (
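For reference, the three extraction patterns above, exercised with Python's re module on two of the documented input shapes (Snowflake's REGEXP_EXTRACT semantics differ in minor ways, but the patterns themselves are unchanged):

    import re

    plain = "1-6"                                 # documented 'y-m' shape
    re.match(r"^[+-]?\d+", plain).group(0)        # -> '1'  (years part)
    re.search(r"-(\d+)$", plain).group(1)         # -> '6'  (months part)

    quoted = "INTERVAL '1-6' YEAR TO MONTH"       # documented INTERVAL shape
    re.search(r"'([^']+)'", quoted).group(1)      # -> '1-6' (quoted value)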
snowflake/snowpark_connect/expression/map_udf.py

@@ -244,6 +244,7 @@ def map_common_inline_user_defined_udf(
         # All Scala UDFs return Variant, so we always need to cast back to the original type
         result_expr = snowpark_fn.cast(udf_call_expr, original_return_type)
         result_type = original_return_type
+
     elif isinstance(original_return_type, (MapType, StructType)) and isinstance(
         processed_return_type, VariantType
     ):