snowpark-connect 0.27.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_connect/__init__.py +1 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +8 -4
- snowflake/snowpark_connect/client/__init__.py +15 -0
- snowflake/snowpark_connect/client/error_utils.py +30 -0
- snowflake/snowpark_connect/client/exceptions.py +36 -0
- snowflake/snowpark_connect/client/query_results.py +90 -0
- snowflake/snowpark_connect/client/server.py +680 -0
- snowflake/snowpark_connect/client/utils/__init__.py +10 -0
- snowflake/snowpark_connect/client/utils/session.py +85 -0
- snowflake/snowpark_connect/column_name_handler.py +404 -243
- snowflake/snowpark_connect/column_qualifier.py +43 -0
- snowflake/snowpark_connect/config.py +237 -23
- snowflake/snowpark_connect/constants.py +2 -0
- snowflake/snowpark_connect/dataframe_container.py +102 -8
- snowflake/snowpark_connect/date_time_format_mapping.py +71 -13
- snowflake/snowpark_connect/error/error_codes.py +50 -0
- snowflake/snowpark_connect/error/error_utils.py +172 -23
- snowflake/snowpark_connect/error/exceptions.py +13 -4
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +15 -160
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +26 -20
- snowflake/snowpark_connect/execute_plan/utils.py +5 -1
- snowflake/snowpark_connect/expression/function_defaults.py +9 -2
- snowflake/snowpark_connect/expression/hybrid_column_map.py +53 -5
- snowflake/snowpark_connect/expression/literal.py +37 -13
- snowflake/snowpark_connect/expression/map_cast.py +123 -5
- snowflake/snowpark_connect/expression/map_expression.py +80 -27
- snowflake/snowpark_connect/expression/map_extension.py +322 -12
- snowflake/snowpark_connect/expression/map_sql_expression.py +316 -81
- snowflake/snowpark_connect/expression/map_udf.py +85 -20
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +451 -173
- snowflake/snowpark_connect/expression/map_unresolved_function.py +2748 -746
- snowflake/snowpark_connect/expression/map_unresolved_star.py +87 -23
- snowflake/snowpark_connect/expression/map_update_fields.py +70 -18
- snowflake/snowpark_connect/expression/map_window_function.py +18 -3
- snowflake/snowpark_connect/includes/jars/{scala-library-2.12.18.jar → sas-scala-udf_2.12-0.2.0.jar} +0 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +1 -1
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +1 -1
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +12 -10
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +14 -2
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +10 -8
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +13 -6
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +65 -17
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +297 -49
- snowflake/snowpark_connect/relation/catalogs/utils.py +12 -4
- snowflake/snowpark_connect/relation/io_utils.py +110 -10
- snowflake/snowpark_connect/relation/map_aggregate.py +196 -255
- snowflake/snowpark_connect/relation/map_catalog.py +5 -1
- snowflake/snowpark_connect/relation/map_column_ops.py +264 -96
- snowflake/snowpark_connect/relation/map_extension.py +263 -29
- snowflake/snowpark_connect/relation/map_join.py +683 -442
- snowflake/snowpark_connect/relation/map_local_relation.py +28 -1
- snowflake/snowpark_connect/relation/map_map_partitions.py +83 -8
- snowflake/snowpark_connect/relation/map_relation.py +48 -19
- snowflake/snowpark_connect/relation/map_row_ops.py +310 -91
- snowflake/snowpark_connect/relation/map_show_string.py +13 -6
- snowflake/snowpark_connect/relation/map_sql.py +1233 -222
- snowflake/snowpark_connect/relation/map_stats.py +48 -9
- snowflake/snowpark_connect/relation/map_subquery_alias.py +11 -2
- snowflake/snowpark_connect/relation/map_udtf.py +14 -4
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +53 -14
- snowflake/snowpark_connect/relation/read/map_read.py +134 -43
- snowflake/snowpark_connect/relation/read/map_read_csv.py +255 -45
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +17 -5
- snowflake/snowpark_connect/relation/read/map_read_json.py +320 -85
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +142 -27
- snowflake/snowpark_connect/relation/read/map_read_partitioned_parquet.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +11 -3
- snowflake/snowpark_connect/relation/read/map_read_table.py +82 -5
- snowflake/snowpark_connect/relation/read/map_read_text.py +18 -3
- snowflake/snowpark_connect/relation/read/metadata_utils.py +170 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +36 -3
- snowflake/snowpark_connect/relation/read/utils.py +50 -5
- snowflake/snowpark_connect/relation/stage_locator.py +91 -55
- snowflake/snowpark_connect/relation/utils.py +128 -5
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +19 -3
- snowflake/snowpark_connect/relation/write/map_write.py +929 -319
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +8 -2
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +110 -48
- snowflake/snowpark_connect/server.py +546 -456
- snowflake/snowpark_connect/server_common/__init__.py +500 -0
- snowflake/snowpark_connect/snowflake_session.py +65 -0
- snowflake/snowpark_connect/start_server.py +53 -5
- snowflake/snowpark_connect/type_mapping.py +349 -27
- snowflake/snowpark_connect/typed_column.py +9 -7
- snowflake/snowpark_connect/utils/artifacts.py +9 -8
- snowflake/snowpark_connect/utils/cache.py +49 -27
- snowflake/snowpark_connect/utils/concurrent.py +36 -1
- snowflake/snowpark_connect/utils/context.py +187 -37
- snowflake/snowpark_connect/utils/describe_query_cache.py +68 -53
- snowflake/snowpark_connect/utils/env_utils.py +5 -1
- snowflake/snowpark_connect/utils/expression_transformer.py +172 -0
- snowflake/snowpark_connect/utils/identifiers.py +137 -3
- snowflake/snowpark_connect/utils/io_utils.py +57 -1
- snowflake/snowpark_connect/utils/java_stored_procedure.py +125 -0
- snowflake/snowpark_connect/utils/java_udaf_utils.py +303 -0
- snowflake/snowpark_connect/utils/java_udtf_utils.py +239 -0
- snowflake/snowpark_connect/utils/jvm_udf_utils.py +248 -0
- snowflake/snowpark_connect/utils/open_telemetry.py +516 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +8 -4
- snowflake/snowpark_connect/utils/patch_spark_line_number.py +181 -0
- snowflake/snowpark_connect/utils/profiling.py +25 -8
- snowflake/snowpark_connect/utils/scala_udf_utils.py +101 -332
- snowflake/snowpark_connect/utils/sequence.py +21 -0
- snowflake/snowpark_connect/utils/session.py +64 -28
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +51 -9
- snowflake/snowpark_connect/utils/spcs_logger.py +290 -0
- snowflake/snowpark_connect/utils/telemetry.py +163 -22
- snowflake/snowpark_connect/utils/temporary_view_cache.py +67 -0
- snowflake/snowpark_connect/utils/temporary_view_helper.py +334 -0
- snowflake/snowpark_connect/utils/udf_cache.py +117 -41
- snowflake/snowpark_connect/utils/udf_helper.py +39 -37
- snowflake/snowpark_connect/utils/udf_utils.py +133 -14
- snowflake/snowpark_connect/utils/udtf_helper.py +8 -1
- snowflake/snowpark_connect/utils/udtf_utils.py +46 -31
- snowflake/snowpark_connect/utils/upload_java_jar.py +57 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/dp_session.py +6 -2
- snowflake/snowpark_decoder/spark_decoder.py +12 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-submit +2 -2
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/METADATA +14 -7
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/RECORD +129 -167
- snowflake/snowpark_connect/hidden_column.py +0 -39
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-trimmed-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-connect-client-jvm_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.27.0.data → snowpark_connect-1.6.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.27.0.dist-info → snowpark_connect-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -11,6 +11,7 @@ import pyarrow as pa
|
|
|
11
11
|
import pyspark.sql.connect.proto.relations_pb2 as relation_proto
|
|
12
12
|
|
|
13
13
|
from snowflake import snowpark
|
|
14
|
+
from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
|
|
14
15
|
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
|
|
15
16
|
from snowflake.snowpark._internal.utils import is_in_stored_procedure
|
|
16
17
|
from snowflake.snowpark.types import LongType, StructField, StructType
|
|
@@ -18,7 +19,10 @@ from snowflake.snowpark_connect import tcm
|
|
|
18
19
|
from snowflake.snowpark_connect.column_name_handler import (
|
|
19
20
|
make_column_names_snowpark_compatible,
|
|
20
21
|
)
|
|
22
|
+
from snowflake.snowpark_connect.config import global_config
|
|
21
23
|
from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
|
|
24
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
25
|
+
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
|
|
22
26
|
from snowflake.snowpark_connect.type_mapping import (
|
|
23
27
|
get_python_sql_utils_class,
|
|
24
28
|
map_json_schema_to_snowpark,
|
|
@@ -242,6 +246,16 @@ def map_local_relation(
|
|
|
242
246
|
# _create_temp_stage() changes were not ported to the internal connector, leading to this
|
|
243
247
|
# error on TCM and in notebooks (sproc):
|
|
244
248
|
# TypeError: _create_temp_stage() takes 7 positional arguments but 8 were given
|
|
249
|
+
#
|
|
250
|
+
# For large local relations (rows * cols >= ARRAY_BIND_THRESHOLD), use PyArrow path for better performance.
|
|
251
|
+
# The PyArrow path uses stage operations (5-6 queries), which is more efficient for large data than batch inserts.
|
|
252
|
+
|
|
253
|
+
enable_optimization = global_config._get_config_setting(
|
|
254
|
+
"snowpark.connect.localRelation.optimizeSmallData"
|
|
255
|
+
)
|
|
256
|
+
use_vectorized_scanner = global_config._get_config_setting(
|
|
257
|
+
"snowpark.connect.parquet.useVectorizedScanner"
|
|
258
|
+
)
|
|
245
259
|
use_pyarrow = (
|
|
246
260
|
not is_in_stored_procedure()
|
|
247
261
|
# TODO: SNOW-2220726 investigate why use_pyarrow failed in TCM:
|
|
@@ -253,12 +267,19 @@ def map_local_relation(
|
|
|
253
267
|
current_schema.strip('"') if current_schema is not None else "",
|
|
254
268
|
)
|
|
255
269
|
is not None
|
|
270
|
+
and (
|
|
271
|
+
# When optimization is disabled, always use PyArrow (preserves row ordering that some tests rely on)
|
|
272
|
+
not enable_optimization
|
|
273
|
+
# When optimization is enabled, use PyArrow only for large data for better performance.
|
|
274
|
+
or (table.num_rows * table.num_columns >= ARRAY_BIND_THRESHOLD)
|
|
275
|
+
)
|
|
256
276
|
)
|
|
257
277
|
|
|
258
278
|
if use_pyarrow:
|
|
259
279
|
snowpark_df: snowpark.DataFrame = session.create_dataframe(
|
|
260
280
|
# Rename the columns to match the Snowpark schema before creating.
|
|
261
281
|
data=table.rename_columns([unquote_if_quoted(c) for c in new_columns]),
|
|
282
|
+
use_vectorized_scanner=use_vectorized_scanner,
|
|
262
283
|
)
|
|
263
284
|
|
|
264
285
|
# Cast the columns to the correct types based on the schema as create_dataframe will
|
|
@@ -273,6 +294,9 @@ def map_local_relation(
|
|
|
273
294
|
snowpark_df = snowpark_df.select(*casted_columns)
|
|
274
295
|
|
|
275
296
|
else:
|
|
297
|
+
# For small datasets (< ARRAY_BIND_THRESHOLD), use List[Row] path.
|
|
298
|
+
# Snowpark's SnowflakeValues will use inline VALUES clause (lazy, no queries) for small data,
|
|
299
|
+
# or temp table with batch insert (lazy, 3 queries on action) if it grows larger.
|
|
276
300
|
pylist_df = [
|
|
277
301
|
list(row)
|
|
278
302
|
for row in zip(*(col.to_pylist() for col in table.itercolumns()))
|
|
@@ -325,11 +349,14 @@ def map_local_relation(
|
|
|
325
349
|
spark_column_names=spark_column_names,
|
|
326
350
|
snowpark_column_names=new_columns,
|
|
327
351
|
column_metadata=column_metadata,
|
|
352
|
+
snowpark_column_types=[f.datatype for f in snowpark_schema.fields],
|
|
328
353
|
)
|
|
329
354
|
else:
|
|
330
|
-
|
|
355
|
+
exception = SnowparkConnectNotImplementedError(
|
|
331
356
|
"LocalRelation without data & schema is not supported"
|
|
332
357
|
)
|
|
358
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
359
|
+
raise exception
|
|
333
360
|
|
|
334
361
|
|
|
335
362
|
def map_range(
|
|
@@ -8,11 +8,19 @@ from pyspark.sql.connect.proto.expressions_pb2 import CommonInlineUserDefinedFun
|
|
|
8
8
|
import snowflake.snowpark.functions as snowpark_fn
|
|
9
9
|
from snowflake import snowpark
|
|
10
10
|
from snowflake.snowpark.types import StructType
|
|
11
|
+
from snowflake.snowpark_connect.column_name_handler import make_unique_snowpark_name
|
|
11
12
|
from snowflake.snowpark_connect.constants import MAP_IN_ARROW_EVAL_TYPE
|
|
12
13
|
from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
|
|
14
|
+
from snowflake.snowpark_connect.expression.map_unresolved_star import (
|
|
15
|
+
map_unresolved_star_as_single_column,
|
|
16
|
+
)
|
|
17
|
+
from snowflake.snowpark_connect.expression.typer import ExpressionTyper
|
|
13
18
|
from snowflake.snowpark_connect.relation.map_relation import map_relation
|
|
14
19
|
from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
|
|
15
|
-
from snowflake.snowpark_connect.utils.
|
|
20
|
+
from snowflake.snowpark_connect.utils.java_udtf_utils import (
|
|
21
|
+
JAVA_UDTF_PREFIX,
|
|
22
|
+
create_java_udtf_for_scala_flatmap_handling,
|
|
23
|
+
)
|
|
16
24
|
from snowflake.snowpark_connect.utils.pandas_udtf_utils import (
|
|
17
25
|
create_pandas_udtf,
|
|
18
26
|
create_pandas_udtf_with_arrow,
|
|
@@ -53,18 +61,18 @@ def _call_udtf(
|
|
|
53
61
|
).cast("int"),
|
|
54
62
|
)
|
|
55
63
|
|
|
56
|
-
udtf_columns = input_df.columns + [
|
|
64
|
+
udtf_columns = [f"snowflake_jtf_{column}" for column in input_df.columns] + [
|
|
65
|
+
"_DUMMY_PARTITION_KEY"
|
|
66
|
+
]
|
|
57
67
|
|
|
58
68
|
tfc = snowpark_fn.call_table_function(udtf_name, *udtf_columns).over(
|
|
59
69
|
partition_by=[snowpark_fn.col("_DUMMY_PARTITION_KEY")]
|
|
60
70
|
)
|
|
61
71
|
|
|
62
|
-
#
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
else:
|
|
67
|
-
result_df_with_dummy = input_df_with_dummy.select(tfc)
|
|
72
|
+
# Overwrite the input_df columns to prevent name conflicts with UDTF output columns
|
|
73
|
+
result_df_with_dummy = input_df_with_dummy.to_df(udtf_columns).join_table_function(
|
|
74
|
+
tfc
|
|
75
|
+
)
|
|
68
76
|
|
|
69
77
|
output_cols = [field.name for field in return_type.fields]
|
|
70
78
|
|
|
@@ -95,6 +103,73 @@ def _map_with_pandas_udtf(
|
|
|
95
103
|
else udf_proto.scalar_scala_udf.outputType
|
|
96
104
|
)
|
|
97
105
|
|
|
106
|
+
if udf_proto.WhichOneof("function") == "scalar_scala_udf":
|
|
107
|
+
assert (
|
|
108
|
+
len(udf_proto.scalar_scala_udf.inputTypes) == 1
|
|
109
|
+
), "len(inputTypes) should be 1 for map and flatMap operations"
|
|
110
|
+
|
|
111
|
+
udtf_name = create_java_udtf_for_scala_flatmap_handling(udf_proto)
|
|
112
|
+
|
|
113
|
+
if udf_proto.scalar_scala_udf.inputTypes[0].WhichOneof("kind") == "struct":
|
|
114
|
+
spark_col_name, typed_col = map_unresolved_star_as_single_column(
|
|
115
|
+
udf_proto.arguments[0],
|
|
116
|
+
input_df_container.column_map,
|
|
117
|
+
ExpressionTyper(input_df),
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
udtf_arg_column = typed_col.col
|
|
121
|
+
else:
|
|
122
|
+
udtf_arg_column = snowpark_fn.col(
|
|
123
|
+
input_df_container.column_map.get_snowpark_columns()[0]
|
|
124
|
+
)
|
|
125
|
+
spark_col_name = input_df_container.column_map.get_spark_columns()[0]
|
|
126
|
+
|
|
127
|
+
if udf_proto.scalar_scala_udf.inputTypes[0].WhichOneof("kind") in (
|
|
128
|
+
"map",
|
|
129
|
+
"array",
|
|
130
|
+
):
|
|
131
|
+
udtf_arg_column = snowpark_fn.to_variant(udtf_arg_column)
|
|
132
|
+
|
|
133
|
+
new_snowpark_col_name = make_unique_snowpark_name(spark_col_name)
|
|
134
|
+
|
|
135
|
+
df = input_df.join_table_function(
|
|
136
|
+
snowpark_fn.call_table_function(udtf_name, udtf_arg_column)
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
df = df.select(
|
|
140
|
+
snowpark_fn.cast(
|
|
141
|
+
snowpark_fn.col(JAVA_UDTF_PREFIX + "C1"), return_type
|
|
142
|
+
).alias(new_snowpark_col_name)
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
if udf_proto.scalar_scala_udf.outputType.WhichOneof("kind") == "struct":
|
|
146
|
+
spark_names = [field.name for field in return_type.fields]
|
|
147
|
+
output_snowpark_names = [
|
|
148
|
+
make_unique_snowpark_name(name) for name in spark_names
|
|
149
|
+
]
|
|
150
|
+
output_types = [field.datatype for field in return_type.fields]
|
|
151
|
+
|
|
152
|
+
cols = [
|
|
153
|
+
snowpark_fn.get(
|
|
154
|
+
snowpark_fn.col(new_snowpark_col_name), snowpark_fn.lit(spark_name)
|
|
155
|
+
).alias(snowpark_name)
|
|
156
|
+
for spark_name, snowpark_name in zip(spark_names, output_snowpark_names)
|
|
157
|
+
]
|
|
158
|
+
|
|
159
|
+
if cols:
|
|
160
|
+
df = df.select(*cols)
|
|
161
|
+
else:
|
|
162
|
+
output_types = [return_type]
|
|
163
|
+
output_snowpark_names = [new_snowpark_col_name]
|
|
164
|
+
spark_names = [spark_col_name]
|
|
165
|
+
|
|
166
|
+
return DataFrameContainer.create_with_column_mapping(
|
|
167
|
+
dataframe=df,
|
|
168
|
+
spark_column_names=spark_names,
|
|
169
|
+
snowpark_column_names=output_snowpark_names,
|
|
170
|
+
snowpark_column_types=output_types,
|
|
171
|
+
)
|
|
172
|
+
|
|
98
173
|
# Check if this is mapInArrow (eval_type == 207)
|
|
99
174
|
map_in_arrow = (
|
|
100
175
|
udf_proto.WhichOneof("function") == "python_udf"
|
|
@@ -8,14 +8,16 @@ import pandas
|
|
|
8
8
|
import pyspark.sql.connect.proto.relations_pb2 as relation_proto
|
|
9
9
|
|
|
10
10
|
from snowflake.snowpark_connect.dataframe_container import DataFrameContainer
|
|
11
|
+
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
|
|
12
|
+
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
|
|
11
13
|
from snowflake.snowpark_connect.utils.cache import (
|
|
12
14
|
df_cache_map_get,
|
|
13
15
|
df_cache_map_put_if_absent,
|
|
14
16
|
)
|
|
15
17
|
from snowflake.snowpark_connect.utils.context import (
|
|
16
18
|
get_plan_id_map,
|
|
17
|
-
|
|
18
|
-
|
|
19
|
+
get_spark_session_id,
|
|
20
|
+
not_resolving_fun_args,
|
|
19
21
|
push_operation_scope,
|
|
20
22
|
set_is_aggregate_function,
|
|
21
23
|
set_plan_id_map,
|
|
@@ -73,7 +75,7 @@ def map_relation(
|
|
|
73
75
|
if reuse_parsed_plan and rel.HasField("common") and rel.common.HasField("plan_id"):
|
|
74
76
|
# TODO: remove get_session_id() when we host SAS in Snowflake server
|
|
75
77
|
# Check for cached relation
|
|
76
|
-
cache_entry = df_cache_map_get((
|
|
78
|
+
cache_entry = df_cache_map_get((get_spark_session_id(), rel.common.plan_id))
|
|
77
79
|
if cache_entry is not None:
|
|
78
80
|
if isinstance(cache_entry, DataFrameContainer):
|
|
79
81
|
set_plan_id_map(rel.common.plan_id, cache_entry)
|
|
@@ -103,7 +105,9 @@ def map_relation(
|
|
|
103
105
|
else:
|
|
104
106
|
# This happens when the relation is empty, usually because the incoming message
|
|
105
107
|
# type was incorrectly routed here.
|
|
106
|
-
|
|
108
|
+
exception = SnowparkConnectNotImplementedError("No Relation Type")
|
|
109
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
110
|
+
raise exception
|
|
107
111
|
|
|
108
112
|
result: DataFrameContainer | pandas.DataFrame
|
|
109
113
|
operation = rel.WhichOneof("rel_type")
|
|
@@ -121,11 +125,19 @@ def map_relation(
|
|
|
121
125
|
case relation_proto.Aggregate.GroupType.GROUP_TYPE_PIVOT:
|
|
122
126
|
result = map_aggregate.map_pivot_aggregate(rel)
|
|
123
127
|
case other:
|
|
124
|
-
|
|
128
|
+
exception = SnowparkConnectNotImplementedError(
|
|
129
|
+
f"AGGREGATE {other}"
|
|
130
|
+
)
|
|
131
|
+
attach_custom_error_code(
|
|
132
|
+
exception, ErrorCodes.UNSUPPORTED_OPERATION
|
|
133
|
+
)
|
|
134
|
+
raise exception
|
|
125
135
|
case "approx_quantile":
|
|
126
136
|
result = map_stats.map_approx_quantile(rel)
|
|
127
137
|
case "as_of_join":
|
|
128
|
-
|
|
138
|
+
exception = SnowparkConnectNotImplementedError("AS_OF_JOIN")
|
|
139
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
140
|
+
raise exception
|
|
129
141
|
case "catalog": # TODO: order these alphabetically
|
|
130
142
|
result = map_catalog.map_catalog(rel.catalog)
|
|
131
143
|
case "collect_metrics":
|
|
@@ -150,7 +162,10 @@ def map_relation(
|
|
|
150
162
|
case "drop_na":
|
|
151
163
|
result = map_row_ops.map_dropna(rel)
|
|
152
164
|
case "extension":
|
|
153
|
-
|
|
165
|
+
# Extensions can be passed as function args, and we need to reset the context here.
|
|
166
|
+
# Matters only for resolving alias expressions in the extensions rel.
|
|
167
|
+
with not_resolving_fun_args():
|
|
168
|
+
result = map_extension.map_extension(rel)
|
|
154
169
|
case "fill_na":
|
|
155
170
|
result = map_row_ops.map_fillna(rel)
|
|
156
171
|
case "filter":
|
|
@@ -167,22 +182,25 @@ def map_relation(
|
|
|
167
182
|
case "limit":
|
|
168
183
|
result = map_row_ops.map_limit(rel)
|
|
169
184
|
case "local_relation":
|
|
170
|
-
result = map_local_relation.map_local_relation(
|
|
185
|
+
result = map_local_relation.map_local_relation(
|
|
186
|
+
rel
|
|
187
|
+
).without_materialization()
|
|
171
188
|
df_cache_map_put_if_absent(
|
|
172
|
-
(
|
|
189
|
+
(get_spark_session_id(), rel.common.plan_id), lambda: result
|
|
173
190
|
)
|
|
174
191
|
case "cached_local_relation":
|
|
175
192
|
cached_df = df_cache_map_get(
|
|
176
|
-
(
|
|
193
|
+
(get_spark_session_id(), rel.cached_local_relation.hash)
|
|
177
194
|
)
|
|
178
195
|
if cached_df is None:
|
|
179
|
-
|
|
196
|
+
exception = ValueError(
|
|
180
197
|
f"Local relation with hash {rel.cached_local_relation.hash} not found in cache."
|
|
181
198
|
)
|
|
199
|
+
attach_custom_error_code(exception, ErrorCodes.INTERNAL_ERROR)
|
|
200
|
+
raise exception
|
|
182
201
|
return cached_df
|
|
183
202
|
case "map_partitions":
|
|
184
|
-
|
|
185
|
-
result = map_map_partitions.map_map_partitions(rel)
|
|
203
|
+
result = map_map_partitions.map_map_partitions(rel)
|
|
186
204
|
case "offset":
|
|
187
205
|
result = map_row_ops.map_offset(rel)
|
|
188
206
|
case "project":
|
|
@@ -214,14 +232,13 @@ def map_relation(
|
|
|
214
232
|
case "sample":
|
|
215
233
|
sampled_df_not_evaluated = map_row_ops.map_sample(rel)
|
|
216
234
|
df_cache_map_put_if_absent(
|
|
217
|
-
(
|
|
235
|
+
(get_spark_session_id(), rel.common.plan_id),
|
|
218
236
|
lambda: sampled_df_not_evaluated,
|
|
219
|
-
True,
|
|
220
237
|
)
|
|
221
238
|
|
|
222
239
|
# We will retrieve from cache and return that, because insertion to cache
|
|
223
240
|
# triggers evaluation.
|
|
224
|
-
result = df_cache_map_get((
|
|
241
|
+
result = df_cache_map_get((get_spark_session_id(), rel.common.plan_id))
|
|
225
242
|
case "sample_by":
|
|
226
243
|
result = map_sample_by.map_sample_by(rel)
|
|
227
244
|
case "set_op":
|
|
@@ -233,7 +250,13 @@ def map_relation(
|
|
|
233
250
|
case relation_proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT:
|
|
234
251
|
result = map_row_ops.map_except(rel)
|
|
235
252
|
case other:
|
|
236
|
-
|
|
253
|
+
exception = SnowparkConnectNotImplementedError(
|
|
254
|
+
f"SET_OP {other}"
|
|
255
|
+
)
|
|
256
|
+
attach_custom_error_code(
|
|
257
|
+
exception, ErrorCodes.UNSUPPORTED_OPERATION
|
|
258
|
+
)
|
|
259
|
+
raise exception
|
|
237
260
|
case "show_string":
|
|
238
261
|
result = map_show_string.map_show_string(rel)
|
|
239
262
|
case "sort":
|
|
@@ -259,11 +282,17 @@ def map_relation(
|
|
|
259
282
|
case "with_columns_renamed":
|
|
260
283
|
result = map_column_ops.map_with_columns_renamed(rel)
|
|
261
284
|
case "with_relations":
|
|
262
|
-
|
|
285
|
+
exception = SnowparkConnectNotImplementedError("WITH_RELATIONS")
|
|
286
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
287
|
+
raise exception
|
|
263
288
|
case "group_map":
|
|
264
289
|
result = map_column_ops.map_group_map(rel)
|
|
265
290
|
case other:
|
|
266
|
-
|
|
291
|
+
exception = SnowparkConnectNotImplementedError(
|
|
292
|
+
f"Other Relation {other}"
|
|
293
|
+
)
|
|
294
|
+
attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_OPERATION)
|
|
295
|
+
raise exception
|
|
267
296
|
|
|
268
297
|
# Store container in plan cache
|
|
269
298
|
if isinstance(result, DataFrameContainer):
|