snowpark-connect 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic. Click here for more details.
- snowflake/snowpark_connect/__init__.py +23 -0
- snowflake/snowpark_connect/analyze_plan/__init__.py +3 -0
- snowflake/snowpark_connect/analyze_plan/map_tree_string.py +38 -0
- snowflake/snowpark_connect/column_name_handler.py +735 -0
- snowflake/snowpark_connect/config.py +576 -0
- snowflake/snowpark_connect/constants.py +47 -0
- snowflake/snowpark_connect/control_server.py +52 -0
- snowflake/snowpark_connect/dataframe_name_handler.py +54 -0
- snowflake/snowpark_connect/date_time_format_mapping.py +399 -0
- snowflake/snowpark_connect/empty_dataframe.py +18 -0
- snowflake/snowpark_connect/error/__init__.py +11 -0
- snowflake/snowpark_connect/error/error_mapping.py +6174 -0
- snowflake/snowpark_connect/error/error_utils.py +321 -0
- snowflake/snowpark_connect/error/exceptions.py +24 -0
- snowflake/snowpark_connect/execute_plan/__init__.py +3 -0
- snowflake/snowpark_connect/execute_plan/map_execution_command.py +204 -0
- snowflake/snowpark_connect/execute_plan/map_execution_root.py +173 -0
- snowflake/snowpark_connect/execute_plan/utils.py +183 -0
- snowflake/snowpark_connect/expression/__init__.py +3 -0
- snowflake/snowpark_connect/expression/literal.py +90 -0
- snowflake/snowpark_connect/expression/map_cast.py +343 -0
- snowflake/snowpark_connect/expression/map_expression.py +293 -0
- snowflake/snowpark_connect/expression/map_extension.py +104 -0
- snowflake/snowpark_connect/expression/map_sql_expression.py +633 -0
- snowflake/snowpark_connect/expression/map_udf.py +142 -0
- snowflake/snowpark_connect/expression/map_unresolved_attribute.py +241 -0
- snowflake/snowpark_connect/expression/map_unresolved_extract_value.py +85 -0
- snowflake/snowpark_connect/expression/map_unresolved_function.py +9450 -0
- snowflake/snowpark_connect/expression/map_unresolved_star.py +218 -0
- snowflake/snowpark_connect/expression/map_update_fields.py +164 -0
- snowflake/snowpark_connect/expression/map_window_function.py +258 -0
- snowflake/snowpark_connect/expression/typer.py +125 -0
- snowflake/snowpark_connect/includes/__init__.py +0 -0
- snowflake/snowpark_connect/includes/jars/antlr4-runtime-4.9.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-cli-1.5.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-codec-1.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections-3.2.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-collections4-4.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compiler-3.1.9.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-compress-1.26.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-crypto-1.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-dbcp-1.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-io-2.16.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang-2.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-lang3-3.12.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-logging-1.1.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-math3-3.6.1.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-pool-1.5.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/commons-text-1.10.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/hadoop-client-api-3.3.4.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-annotations-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-core-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-databind-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-dataformat-yaml-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-datatype-jsr310-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- snowflake/snowpark_connect/includes/jars/jackson-module-scala_2.12-2.15.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-ast_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-core_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-jackson_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/json4s-scalap_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/kryo-shaded-4.0.2.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-1.2-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-api-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-core-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/log4j-slf4j2-impl-2.20.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-collection-compat_2.12-2.7.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-compiler-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-library-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-parser-combinators_2.12-2.3.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-reflect-2.12.18.jar +0 -0
- snowflake/snowpark_connect/includes/jars/scala-xml_2.12-2.1.0.jar +0 -0
- snowflake/snowpark_connect/includes/jars/slf4j-api-2.0.7.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-catalyst_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-common-utils_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-core_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-graphx_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive-thriftserver_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-hive_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kubernetes_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-kvstore_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-launcher_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mesos_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib-local_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-mllib_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-common_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-network-shuffle_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-repl_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sketch_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql-api_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-sql_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-streaming_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-tags_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-unsafe_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/jars/spark-yarn_2.12-3.5.6.jar +0 -0
- snowflake/snowpark_connect/includes/python/__init__.py +21 -0
- snowflake/snowpark_connect/includes/python/pyspark/__init__.py +173 -0
- snowflake/snowpark_connect/includes/python/pyspark/_globals.py +71 -0
- snowflake/snowpark_connect/includes/python/pyspark/_typing.pyi +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/accumulators.py +341 -0
- snowflake/snowpark_connect/includes/python/pyspark/broadcast.py +383 -0
- snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/__init__.py +8 -0
- snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle.py +948 -0
- snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/cloudpickle_fast.py +844 -0
- snowflake/snowpark_connect/includes/python/pyspark/cloudpickle/compat.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/conf.py +276 -0
- snowflake/snowpark_connect/includes/python/pyspark/context.py +2601 -0
- snowflake/snowpark_connect/includes/python/pyspark/daemon.py +218 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/__init__.py +70 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/error_classes.py +889 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/base.py +228 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/captured.py +307 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/exceptions/connect.py +190 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/tests/test_errors.py +60 -0
- snowflake/snowpark_connect/includes/python/pyspark/errors/utils.py +116 -0
- snowflake/snowpark_connect/includes/python/pyspark/files.py +165 -0
- snowflake/snowpark_connect/includes/python/pyspark/find_spark_home.py +95 -0
- snowflake/snowpark_connect/includes/python/pyspark/install.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/instrumentation_utils.py +190 -0
- snowflake/snowpark_connect/includes/python/pyspark/java_gateway.py +248 -0
- snowflake/snowpark_connect/includes/python/pyspark/join.py +118 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/__init__.py +71 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/_typing.pyi +84 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/base.py +414 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/classification.py +4332 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/clustering.py +2188 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/common.py +146 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/__init__.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/base.py +346 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/classification.py +382 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/evaluation.py +291 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/feature.py +258 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/functions.py +77 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/io_utils.py +335 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/pipeline.py +262 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/summarizer.py +120 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/tuning.py +579 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/connect/util.py +173 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/deepspeed_distributor.py +165 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py +306 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/dl_util.py +150 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/evaluation.py +1166 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/feature.py +7474 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/fpm.py +543 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/functions.py +842 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/image.py +271 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/linalg/__init__.py +1382 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/model_cache.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/param/__init__.py +602 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/param/_shared_params_code_gen.py +368 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/param/shared.py +878 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/pipeline.py +451 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/recommendation.py +748 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/regression.py +3335 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/stat.py +523 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_classification.py +53 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_evaluation.py +50 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_feature.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_function.py +114 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_pipeline.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_summarizer.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_connect_tuning.py +46 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py +238 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py +194 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +156 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py +184 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py +78 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +292 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +50 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +152 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_algorithms.py +456 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_base.py +96 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_dl_util.py +186 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_evaluation.py +77 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_feature.py +401 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_functions.py +528 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_image.py +82 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_linalg.py +409 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_model_cache.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_param.py +441 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_persistence.py +546 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_pipeline.py +71 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_stat.py +52 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_training_summary.py +494 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_util.py +85 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/test_wrapper.py +138 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_basic.py +151 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_nested.py +97 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py +143 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tuning.py +551 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_basic.py +137 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_nested.py +96 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tests/tuning/test_tvs_io_pipeline.py +142 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/data.py +100 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/distributor.py +1133 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/log_communication.py +198 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_data_loader.py +137 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_distributor.py +561 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/tests/test_log_communication.py +172 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/torch/torch_run_process_wrapper.py +83 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tree.py +434 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/tuning.py +1741 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/util.py +749 -0
- snowflake/snowpark_connect/includes/python/pyspark/ml/wrapper.py +465 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/__init__.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/_typing.pyi +33 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/classification.py +989 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/clustering.py +1318 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/common.py +174 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/evaluation.py +691 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/feature.py +1085 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/fpm.py +233 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/__init__.py +1653 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/linalg/distributed.py +1662 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/random.py +698 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/recommendation.py +389 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/regression.py +1067 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/KernelDensity.py +59 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/__init__.py +34 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/_statistics.py +409 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/distribution.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/stat/test.py +86 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_algorithms.py +353 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_feature.py +192 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_linalg.py +680 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_stat.py +206 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_streaming_algorithms.py +471 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tests/test_util.py +108 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/tree.py +888 -0
- snowflake/snowpark_connect/includes/python/pyspark/mllib/util.py +659 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/__init__.py +165 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/_typing.py +52 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/accessors.py +989 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/base.py +1804 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/categorical.py +822 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/config.py +539 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/correlation.py +262 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/base.py +519 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/binary_ops.py +98 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/boolean_ops.py +426 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/categorical_ops.py +141 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/complex_ops.py +145 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/date_ops.py +127 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/datetime_ops.py +171 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/null_ops.py +83 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/num_ops.py +588 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/string_ops.py +154 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/timedelta_ops.py +101 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/data_type_ops/udt_ops.py +29 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/datetimes.py +891 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/exceptions.py +150 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/extensions.py +388 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/frame.py +13738 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/generic.py +3560 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/groupby.py +4448 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/__init__.py +21 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/base.py +2783 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/category.py +773 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/datetimes.py +843 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/multi.py +1323 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/numeric.py +210 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexes/timedelta.py +197 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/indexing.py +1862 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/internal.py +1680 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/__init__.py +48 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/common.py +76 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/frame.py +63 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/general_functions.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/groupby.py +93 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/indexes.py +184 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/resample.py +101 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/scalars.py +29 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/series.py +69 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/missing/window.py +168 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/mlflow.py +238 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/namespace.py +3807 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/numpy_compat.py +260 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/__init__.py +17 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/core.py +1213 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/matplotlib.py +928 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/plot/plotly.py +261 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/resample.py +816 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/series.py +7440 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_formatter.py +308 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/sql_processor.py +394 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/strings.py +2371 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/supported_api_gen.py +378 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_any_all.py +177 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_apply_func.py +575 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_binary_ops.py +235 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_combine.py +653 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_compute.py +463 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_corrwith.py +86 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cov.py +151 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_cumulative.py +139 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_describe.py +458 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_eval.py +86 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_melt.py +202 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_missing_data.py +520 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/computation/test_pivot.py +361 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_any_all.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_apply_func.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_binary_ops.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_combine.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py +60 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_corrwith.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cov.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_cumulative.py +90 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_describe.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_eval.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_melt.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_missing_data.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_base.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_complex_ops.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_arithmetic.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_reverse.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/data_type_ops/testing_utils.py +226 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_align.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_basic_slow.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_cov_corrwith.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_frame.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_dot_series.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_index.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_series.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_frame.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/diff_frames_ops/test_parity_setitem_series.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_attrs.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_conversion.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reindexing.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_spark.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_take.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_time_series.py +48 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/frame/test_parity_truncate.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_aggregate.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_apply_func.py +41 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_cumulative.py +67 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_describe.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_groupby.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_head_tail.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_index.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_split_apply.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_align.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py +50 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py +73 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_datetime.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_indexing.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reindex.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_rename.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_reset_index.py +48 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/indexes/test_parity_timedelta.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/io/test_parity_io.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot.py +45 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_matplotlib.py +45 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_frame_plot_plotly.py +49 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_matplotlib.py +53 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/plot/test_parity_series_plot_plotly.py +45 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_all_any.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_arg_ops.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_of.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_as_type.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_conversion.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_cumulative.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_index.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_missing_data.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_series.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_sort.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/series/test_parity_stat.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_categorical.py +66 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_config.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_csv.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_conversion.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_dataframe_spark_io.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_default_index.py +49 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ewm.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_expanding.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_extension.py +49 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_frame_spark.py +53 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexing.py +49 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_indexops_spark.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_internal.py +41 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_namespace.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_numpy_compat.py +60 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames.py +48 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_expanding.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_ops_on_diff_frames_groupby_rolling.py +84 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_repr.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_resample.py +45 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_reshape.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_rolling.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_scalars.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_conversion.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_datetime.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_series_string.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_spark_functions.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_sql.py +43 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_stats.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_typedef.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_utils.py +37 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/connect/test_parity_window.py +39 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_base.py +107 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +224 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +825 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +562 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +368 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +257 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +260 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +178 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_arithmetic.py +184 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +497 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_num_reverse.py +140 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +354 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +219 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +192 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +228 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_align.py +118 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_basic_slow.py +198 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_cov_corrwith.py +181 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_frame.py +103 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_dot_series.py +141 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_index.py +109 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_series.py +136 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_frame.py +125 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/diff_frames_ops/test_setitem_series.py +217 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_attrs.py +384 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_constructor.py +598 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_conversion.py +73 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reindexing.py +869 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_reshaping.py +487 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_spark.py +309 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_take.py +156 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_time_series.py +149 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/frame/test_truncate.py +163 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_aggregate.py +311 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_apply_func.py +524 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_cumulative.py +419 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_describe.py +144 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_groupby.py +979 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_head_tail.py +234 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_index.py +206 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_missing_data.py +421 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_split_apply.py +187 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/groupby/test_stat.py +397 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_align.py +100 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_base.py +2743 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_category.py +484 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_datetime.py +276 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_indexing.py +432 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reindex.py +310 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_rename.py +257 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_reset_index.py +160 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/indexes/test_timedelta.py +128 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/io/test_io.py +137 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot.py +170 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +547 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +285 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot.py +106 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +409 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +247 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_all_any.py +105 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_arg_ops.py +197 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_of.py +137 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_as_type.py +227 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_compute.py +634 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_conversion.py +88 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_cumulative.py +139 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_index.py +475 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_missing_data.py +265 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_series.py +818 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_sort.py +162 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/series/test_stat.py +780 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_categorical.py +741 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_config.py +160 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_csv.py +453 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_conversion.py +281 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_dataframe_spark_io.py +487 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_default_index.py +109 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ewm.py +434 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_expanding.py +253 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_extension.py +152 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_frame_spark.py +162 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_generic_functions.py +234 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexing.py +1339 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_indexops_spark.py +82 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_internal.py +124 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_namespace.py +638 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_numpy_compat.py +200 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +1355 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +655 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +113 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +118 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_repr.py +192 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_resample.py +346 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_reshape.py +495 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_rolling.py +263 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_scalars.py +59 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_conversion.py +85 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_datetime.py +364 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_series_string.py +362 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_spark_functions.py +46 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_sql.py +123 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_stats.py +581 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_typedef.py +447 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_utils.py +301 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/tests/test_window.py +465 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/typedef/typehints.py +874 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/__init__.py +143 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/usage_logging/usage_logger.py +132 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/utils.py +1063 -0
- snowflake/snowpark_connect/includes/python/pyspark/pandas/window.py +2702 -0
- snowflake/snowpark_connect/includes/python/pyspark/profiler.py +489 -0
- snowflake/snowpark_connect/includes/python/pyspark/py.typed +1 -0
- snowflake/snowpark_connect/includes/python/pyspark/python/pyspark/shell.py +123 -0
- snowflake/snowpark_connect/includes/python/pyspark/rdd.py +5518 -0
- snowflake/snowpark_connect/includes/python/pyspark/rddsampler.py +115 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/__init__.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/information.py +69 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/profile.py +317 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/requests.py +539 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/resource/tests/test_resources.py +83 -0
- snowflake/snowpark_connect/includes/python/pyspark/resultiterable.py +45 -0
- snowflake/snowpark_connect/includes/python/pyspark/serializers.py +681 -0
- snowflake/snowpark_connect/includes/python/pyspark/shell.py +123 -0
- snowflake/snowpark_connect/includes/python/pyspark/shuffle.py +854 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/__init__.py +75 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/_typing.pyi +80 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/avro/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/avro/functions.py +188 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/catalog.py +1270 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/column.py +1431 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/conf.py +99 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/_typing.py +90 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/avro/functions.py +107 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/catalog.py +356 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/__init__.py +22 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/artifact.py +412 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/core.py +1689 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/client/reattach.py +340 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/column.py +514 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conf.py +128 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/conversion.py +490 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/dataframe.py +2172 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/expressions.py +1056 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/functions.py +3937 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/group.py +418 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/plan.py +2289 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/__init__.py +25 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.py +203 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2.pyi +2718 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/base_pb2_grpc.py +423 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.py +109 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/catalog_pb2.pyi +1130 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.py +141 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/commands_pb2.pyi +1766 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.py +47 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/common_pb2.pyi +123 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.py +53 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/example_plugins_pb2.pyi +112 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.py +107 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/expressions_pb2.pyi +1507 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.py +195 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/relations_pb2.pyi +3613 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.py +95 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/proto/types_pb2.pyi +980 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/protobuf/functions.py +166 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/readwriter.py +861 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/session.py +952 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/__init__.py +22 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/query.py +295 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/readwriter.py +618 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +87 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/streaming/worker/listener_worker.py +100 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/types.py +301 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udf.py +296 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/udtf.py +200 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/utils.py +58 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/connect/window.py +266 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/context.py +818 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/dataframe.py +5973 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/functions.py +15889 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/group.py +547 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/observation.py +152 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/__init__.py +21 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/__init__.pyi +344 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi +17 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/frame.pyi +20 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/_typing/protocols/series.pyi +20 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/conversion.py +671 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.py +480 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/functions.pyi +132 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/group_ops.py +523 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/map_ops.py +216 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/serializers.py +1019 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/typehints.py +172 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/types.py +972 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/pandas/utils.py +86 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/__init__.py +18 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/protobuf/functions.py +334 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/readwriter.py +2159 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/session.py +2088 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/sql_formatter.py +84 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/__init__.py +21 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/listener.py +1050 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/query.py +746 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/readwriter.py +1652 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/streaming/state.py +288 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_artifact.py +420 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/client/test_client.py +358 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +116 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +35 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_basic.py +3612 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_column.py +1042 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_function.py +2381 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_connect_plan.py +1060 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow.py +163 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_map.py +38 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +48 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_catalog.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_column.py +55 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_conf.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_dataframe.py +96 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_datasources.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_errors.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_functions.py +59 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_group.py +36 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +59 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py +74 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +62 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_map.py +58 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +70 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py +50 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +68 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py +40 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_readwriter.py +46 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_serde.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_types.py +100 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udf.py +100 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_parity_udtf.py +163 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_session.py +181 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/connect/test_utils.py +42 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +623 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +869 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +342 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_map.py +436 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf.py +363 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +592 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +1503 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +392 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +375 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +411 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming.py +401 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +295 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +106 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/streaming/test_streaming_listener.py +558 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow.py +1346 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_map.py +182 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_arrow_python_udf.py +202 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_catalog.py +503 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_column.py +225 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_conf.py +83 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_context.py +201 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_dataframe.py +1931 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_datasources.py +256 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_errors.py +69 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_functions.py +1349 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_group.py +53 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +68 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_readwriter.py +283 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_serde.py +155 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_session.py +412 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_types.py +1581 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf.py +961 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udf_profiler.py +165 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_udtf.py +1456 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/tests/test_utils.py +1686 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/types.py +2558 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/udf.py +714 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/udtf.py +325 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/utils.py +339 -0
- snowflake/snowpark_connect/includes/python/pyspark/sql/window.py +492 -0
- snowflake/snowpark_connect/includes/python/pyspark/statcounter.py +165 -0
- snowflake/snowpark_connect/includes/python/pyspark/status.py +112 -0
- snowflake/snowpark_connect/includes/python/pyspark/storagelevel.py +97 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/__init__.py +22 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/context.py +471 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/dstream.py +933 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/kinesis.py +205 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/listener.py +83 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_context.py +184 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_dstream.py +706 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_kinesis.py +118 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/tests/test_listener.py +160 -0
- snowflake/snowpark_connect/includes/python/pyspark/streaming/util.py +168 -0
- snowflake/snowpark_connect/includes/python/pyspark/taskcontext.py +502 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/__init__.py +21 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/connectutils.py +199 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/mllibutils.py +30 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/mlutils.py +275 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/objects.py +121 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/pandasutils.py +714 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/sqlutils.py +168 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/streamingutils.py +178 -0
- snowflake/snowpark_connect/includes/python/pyspark/testing/utils.py +636 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/__init__.py +16 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_appsubmit.py +306 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_broadcast.py +196 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_conf.py +44 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_context.py +346 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_daemon.py +89 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_install_spark.py +124 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_join.py +69 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_memory_profiler.py +167 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_pin_thread.py +194 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_profiler.py +168 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_rdd.py +939 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddbarrier.py +52 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_rddsampler.py +66 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_readwrite.py +368 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_serializers.py +257 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_shuffle.py +267 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_stage_sched.py +153 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_statcounter.py +130 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_taskcontext.py +350 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_util.py +97 -0
- snowflake/snowpark_connect/includes/python/pyspark/tests/test_worker.py +271 -0
- snowflake/snowpark_connect/includes/python/pyspark/traceback_utils.py +81 -0
- snowflake/snowpark_connect/includes/python/pyspark/util.py +416 -0
- snowflake/snowpark_connect/includes/python/pyspark/version.py +19 -0
- snowflake/snowpark_connect/includes/python/pyspark/worker.py +1307 -0
- snowflake/snowpark_connect/includes/python/pyspark/worker_util.py +46 -0
- snowflake/snowpark_connect/proto/__init__.py +10 -0
- snowflake/snowpark_connect/proto/control_pb2.py +35 -0
- snowflake/snowpark_connect/proto/control_pb2.pyi +38 -0
- snowflake/snowpark_connect/proto/control_pb2_grpc.py +183 -0
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.py +35 -0
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2.pyi +53 -0
- snowflake/snowpark_connect/proto/snowflake_rdd_pb2.pyi +39 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.py +47 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2.pyi +111 -0
- snowflake/snowpark_connect/relation/__init__.py +3 -0
- snowflake/snowpark_connect/relation/catalogs/__init__.py +12 -0
- snowflake/snowpark_connect/relation/catalogs/abstract_spark_catalog.py +287 -0
- snowflake/snowpark_connect/relation/catalogs/snowflake_catalog.py +467 -0
- snowflake/snowpark_connect/relation/catalogs/utils.py +51 -0
- snowflake/snowpark_connect/relation/io_utils.py +76 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +322 -0
- snowflake/snowpark_connect/relation/map_catalog.py +151 -0
- snowflake/snowpark_connect/relation/map_column_ops.py +1068 -0
- snowflake/snowpark_connect/relation/map_crosstab.py +48 -0
- snowflake/snowpark_connect/relation/map_extension.py +412 -0
- snowflake/snowpark_connect/relation/map_join.py +341 -0
- snowflake/snowpark_connect/relation/map_local_relation.py +326 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +146 -0
- snowflake/snowpark_connect/relation/map_relation.py +253 -0
- snowflake/snowpark_connect/relation/map_row_ops.py +716 -0
- snowflake/snowpark_connect/relation/map_sample_by.py +35 -0
- snowflake/snowpark_connect/relation/map_show_string.py +50 -0
- snowflake/snowpark_connect/relation/map_sql.py +1874 -0
- snowflake/snowpark_connect/relation/map_stats.py +324 -0
- snowflake/snowpark_connect/relation/map_subquery_alias.py +32 -0
- snowflake/snowpark_connect/relation/map_udtf.py +288 -0
- snowflake/snowpark_connect/relation/read/__init__.py +7 -0
- snowflake/snowpark_connect/relation/read/jdbc_read_dbapi.py +668 -0
- snowflake/snowpark_connect/relation/read/map_read.py +367 -0
- snowflake/snowpark_connect/relation/read/map_read_csv.py +142 -0
- snowflake/snowpark_connect/relation/read/map_read_jdbc.py +108 -0
- snowflake/snowpark_connect/relation/read/map_read_json.py +344 -0
- snowflake/snowpark_connect/relation/read/map_read_parquet.py +194 -0
- snowflake/snowpark_connect/relation/read/map_read_socket.py +59 -0
- snowflake/snowpark_connect/relation/read/map_read_table.py +109 -0
- snowflake/snowpark_connect/relation/read/map_read_text.py +106 -0
- snowflake/snowpark_connect/relation/read/reader_config.py +399 -0
- snowflake/snowpark_connect/relation/read/utils.py +155 -0
- snowflake/snowpark_connect/relation/stage_locator.py +161 -0
- snowflake/snowpark_connect/relation/utils.py +219 -0
- snowflake/snowpark_connect/relation/write/__init__.py +3 -0
- snowflake/snowpark_connect/relation/write/jdbc_write_dbapi.py +339 -0
- snowflake/snowpark_connect/relation/write/map_write.py +436 -0
- snowflake/snowpark_connect/relation/write/map_write_jdbc.py +48 -0
- snowflake/snowpark_connect/resources/java_udfs-1.0-SNAPSHOT.jar +0 -0
- snowflake/snowpark_connect/resources_initializer.py +75 -0
- snowflake/snowpark_connect/server.py +1136 -0
- snowflake/snowpark_connect/start_server.py +32 -0
- snowflake/snowpark_connect/tcm.py +8 -0
- snowflake/snowpark_connect/type_mapping.py +1003 -0
- snowflake/snowpark_connect/typed_column.py +94 -0
- snowflake/snowpark_connect/utils/__init__.py +3 -0
- snowflake/snowpark_connect/utils/artifacts.py +48 -0
- snowflake/snowpark_connect/utils/attribute_handling.py +72 -0
- snowflake/snowpark_connect/utils/cache.py +84 -0
- snowflake/snowpark_connect/utils/concurrent.py +124 -0
- snowflake/snowpark_connect/utils/context.py +390 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +231 -0
- snowflake/snowpark_connect/utils/interrupt.py +85 -0
- snowflake/snowpark_connect/utils/io_utils.py +35 -0
- snowflake/snowpark_connect/utils/pandas_udtf_utils.py +117 -0
- snowflake/snowpark_connect/utils/profiling.py +47 -0
- snowflake/snowpark_connect/utils/session.py +180 -0
- snowflake/snowpark_connect/utils/snowpark_connect_logging.py +38 -0
- snowflake/snowpark_connect/utils/telemetry.py +513 -0
- snowflake/snowpark_connect/utils/udf_cache.py +392 -0
- snowflake/snowpark_connect/utils/udf_helper.py +328 -0
- snowflake/snowpark_connect/utils/udf_utils.py +310 -0
- snowflake/snowpark_connect/utils/udtf_helper.py +420 -0
- snowflake/snowpark_connect/utils/udtf_utils.py +799 -0
- snowflake/snowpark_connect/utils/xxhash64.py +247 -0
- snowflake/snowpark_connect/version.py +6 -0
- snowpark_connect-0.20.2.data/scripts/snowpark-connect +71 -0
- snowpark_connect-0.20.2.data/scripts/snowpark-session +11 -0
- snowpark_connect-0.20.2.data/scripts/snowpark-submit +354 -0
- snowpark_connect-0.20.2.dist-info/METADATA +37 -0
- snowpark_connect-0.20.2.dist-info/RECORD +879 -0
- snowpark_connect-0.20.2.dist-info/WHEEL +5 -0
- snowpark_connect-0.20.2.dist-info/licenses/LICENSE.txt +202 -0
- snowpark_connect-0.20.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1133 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
3
|
+
# contributor license agreements. See the NOTICE file distributed with
|
|
4
|
+
# this work for additional information regarding copyright ownership.
|
|
5
|
+
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
6
|
+
# (the "License"); you may not use this file except in compliance with
|
|
7
|
+
# the License. You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
import json
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
import collections
|
|
20
|
+
import logging
|
|
21
|
+
import math
|
|
22
|
+
import os
|
|
23
|
+
import random
|
|
24
|
+
import re
|
|
25
|
+
import shutil
|
|
26
|
+
import subprocess
|
|
27
|
+
import sys
|
|
28
|
+
import tempfile
|
|
29
|
+
import textwrap
|
|
30
|
+
import time
|
|
31
|
+
from typing import (
|
|
32
|
+
Union,
|
|
33
|
+
Callable,
|
|
34
|
+
List,
|
|
35
|
+
Dict,
|
|
36
|
+
Optional,
|
|
37
|
+
Any,
|
|
38
|
+
Tuple,
|
|
39
|
+
Generator,
|
|
40
|
+
Iterator,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
from pyspark import cloudpickle
|
|
44
|
+
from pyspark.resource.information import ResourceInformation
|
|
45
|
+
from pyspark.sql import DataFrame, SparkSession
|
|
46
|
+
from pyspark.taskcontext import BarrierTaskContext
|
|
47
|
+
from pyspark.ml.torch.log_communication import ( # type: ignore
|
|
48
|
+
LogStreamingClient,
|
|
49
|
+
LogStreamingServer,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]:
    """Return the resources (e.g. ``"gpu"``) that ``session`` reports.

    ``session.sparkContext.resources`` is tried first. If accessing it raises any
    exception, the resources are requested via ``session._client._resources()``
    instead.
    """
    try:
        return session.sparkContext.resources
    except Exception:
        # Fallback path; presumably taken for Spark Connect sessions, which have
        # no usable SparkContext (TODO confirm).
        return session._client._resources()  # type: ignore[attr-defined]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
    """Look up ``key`` in the runtime configuration of ``spark``.

    Parameters
    ----------
    spark : :class:`SparkSession`
        The :class:`SparkSession` whose ``conf`` is queried.
    key : str
        Name of the configuration entry.
    default_value : str
        Value returned when ``key`` is not set.

    Returns
    -------
    str
        The configured value, or ``default_value`` when the key is unset.
    """
    conf_value = spark.conf.get(key, default_value)
    # Narrows the Optional return type of ``conf.get`` for type checkers.
    assert conf_value is not None
    return conf_value
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# TODO(SPARK-41589): will move the functions and tests to an external file
# once we are in agreement about which functions should be in utils.py
def _get_conf_boolean(spark: SparkSession, key: str, default_value: str) -> bool:
    """Read ``key`` from the Spark configuration and interpret it as a boolean.

    Parameters
    ----------
    spark : :class:`SparkSession`
        The session whose configuration is read (via :func:`_get_conf`).
    key : str
        Name of the configuration entry.
    default_value : str
        Value used when ``key`` is not set; should be ``"true"`` or ``"false"``.

    Returns
    -------
    bool
        True if the value is ``"true"`` (case-insensitive), False if ``"false"``.

    Raises
    ------
    ValueError
        If the value is neither ``"true"`` nor ``"false"`` (case-insensitive).
    """
    value = _get_conf(spark=spark, key=key, default_value=default_value).lower()
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently turn a malformed value into False.
    if value not in ("true", "false"):
        raise ValueError(
            f"Spark configuration '{key}' must be 'true' or 'false', got '{value}'."
        )
    return value == "true"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _get_logger(name: str) -> logging.Logger:
    """
    Return the logger called ``name``, set to INFO level.

    If neither this logger nor the root logger has any handlers yet, a
    ``StreamHandler`` writing to stderr is attached to this logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    already_configured = bool(logger.handlers) or bool(logging.getLogger().handlers)
    if already_configured:
        return logger
    logger.addHandler(logging.StreamHandler(sys.stderr))
    return logger
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _get_gpus_owned(context: Union[SparkSession, BarrierTaskContext]) -> List[str]:
    """Return the GPU addresses that Spark scheduled to the calling task.

    Parameters
    ----------
    context : :class:`SparkSession` or :class:`BarrierTaskContext`
        The :class:`SparkSession` or :class:`BarrierTaskContext` that has GPUs available.

    Returns
    -------
    list
        The GPU addresses owned by the caller. When ``CUDA_VISIBLE_DEVICES`` is
        set, each Spark address is used as an integer index into that
        comma-separated list, and the selected entries are returned instead.

    Raises
    ------
    ValueError
        Raised if any address is not a non-negative integer without zero padding.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    # Matched with ``fullmatch`` so the entire address must be an integer;
    # ``re.match`` with an unanchored alternation would accept e.g. "1a".
    pattern = re.compile(r"[1-9][0-9]*|0")
    if isinstance(context, BarrierTaskContext):
        addresses = context.resources()["gpu"].addresses
    else:
        addresses = _get_resources(context)["gpu"].addresses

    if any(not pattern.fullmatch(address) for address in addresses):
        raise ValueError(
            f"Found GPU addresses {addresses} which "
            "are not all in the correct format "
            "for CUDA_VISIBLE_DEVICES, which requires "
            "integers with no zero padding."
        )
    if CUDA_VISIBLE_DEVICES in os.environ:
        gpu_indices = list(map(int, addresses))
        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
        gpu_owned = [gpu_list[i] for i in gpu_indices]
        return gpu_owned
    return addresses
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Each constant's value equals its own name. They are not used in this part of
# the module; presumably they are environment-variable keys used to pass file
# locations to the training process (TODO confirm against their later uses).
SPARK_PARTITION_ARROW_DATA_FILE = "SPARK_PARTITION_ARROW_DATA_FILE"
SPARK_DATAFRAME_SCHEMA_FILE = "SPARK_DATAFRAME_SCHEMA_FILE"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class Distributor:
    """
    The parent class for TorchDistributor. This class shouldn't be instantiated directly.

    Subclasses must pass ``ssl_conf`` (the name of the Spark configuration that
    allows running on a TLS-enabled cluster); without it,
    :meth:`_check_encryption` raises ``RuntimeError``.
    """

    def __init__(
        self,
        num_processes: int = 1,
        local_mode: bool = True,
        use_gpu: bool = True,
        ssl_conf: Optional[str] = None,
    ):
        """Record the training settings and inspect the active Spark session.

        Parameters
        ----------
        num_processes : int, optional
            Number of training processes requested. In local GPU mode it may be
            lowered to the number of GPUs configured on the driver.
        local_mode : bool, optional
            Whether training runs on the driver node instead of on executors.
        use_gpu : bool, optional
            Whether training uses GPUs.
        ssl_conf : str, optional
            Name of the Spark configuration that, when set to "true", allows
            training on a cluster with ``spark.ssl.enabled``.

        Raises
        ------
        RuntimeError
            From :meth:`_get_num_tasks` when GPU usage is requested but the GPU
            configuration is missing.
        """
        from pyspark.sql.utils import is_remote

        self.is_remote = is_remote()
        self.spark = SparkSession.active()

        # indicate whether the server side is local mode
        self.is_spark_local_master = False
        # Refer to 'org.apache.spark.util.Utils#isLocalMaster'
        master = _get_conf(self.spark, "spark.master", "")
        if master == "local" or master.startswith("local["):
            self.is_spark_local_master = True

        self.logger = _get_logger(self.__class__.__name__)
        self.num_processes = num_processes
        self.local_mode = local_mode
        self.use_gpu = use_gpu
        # Must run after the attributes above are set: it reads them and may
        # lower self.num_processes.
        self.num_tasks = self._get_num_tasks()
        self.ssl_conf = ssl_conf

    def _create_input_params(self) -> Dict[str, Any]:
        """Return a copy of the instance attributes without the session, logger,
        SSL conf name and environment flags.

        Raises ``KeyError`` if any of those attributes has not been set.
        """
        input_params = self.__dict__.copy()
        for unneeded_param in [
            "spark",
            "ssl_conf",
            "logger",
            "is_remote",
            "is_spark_local_master",
        ]:
            del input_params[unneeded_param]
        return input_params

    def _get_num_tasks(self) -> int:
        """
        Returns the number of Spark tasks to use for distributed training.

        Without ``use_gpu`` this is ``num_processes``. With ``use_gpu``:

        * executor mode (``local_mode=False``): ``ceil(num_processes /
          spark.task.resource.gpu.amount)``;
        * local mode: ``num_processes``, capped (with a warning) at
          ``spark.driver.resource.gpu.amount``; the capped value is also
          written back to ``self.num_processes``.

        Returns
        -------
        int
            The number of Spark tasks to use for distributed training

        Raises
        ------
        RuntimeError
            Raised when the SparkConf was misconfigured.
        """
        if self.use_gpu:
            if not self.local_mode:
                key = "spark.task.resource.gpu.amount"
                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                if task_gpu_amount < 1:
                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
                return math.ceil(self.num_processes / task_gpu_amount)
            else:
                key = "spark.driver.resource.gpu.amount"
                if "gpu" not in _get_resources(self.spark):
                    raise RuntimeError("GPUs were unable to be found on the driver.")
                num_available_gpus = int(_get_conf(self.spark, key, "0"))
                if num_available_gpus == 0:
                    raise RuntimeError("GPU resources were not configured properly on the driver.")
                if self.num_processes > num_available_gpus:
                    self.logger.warning(
                        "'num_processes' cannot be set to a value greater than the number of "
                        f"available GPUs on the driver, which is {num_available_gpus}. "
                        "'num_processes' was reset to be equal to the number of available GPUs.",
                    )
                    self.num_processes = num_available_gpus
        return self.num_processes

    def _validate_input_params(self) -> None:
        """Raise ``ValueError`` unless ``num_processes`` is positive."""
        if self.num_processes <= 0:
            raise ValueError("num_processes has to be a positive integer")

    def _check_encryption(self) -> None:
        """Checks to see if the user requires encryption of data.
        If required, throw an exception since we don't support that.

        When ``spark.ssl.enabled`` is true and the conf named by ``ssl_conf`` is
        also true, only a warning is logged.

        Raises
        ------
        RuntimeError
            Thrown when the user requires ssl encryption or when the user initializes
            the Distributor parent class.
        """
        # ``__init__`` always assigns ``ssl_conf`` (None by default), so the
        # parent class is detected by a missing/None value, not by ``hasattr``.
        if getattr(self, "ssl_conf", None) is None:
            raise RuntimeError(
                "Distributor doesn't have this functionality. Use TorchDistributor instead."
            )
        is_ssl_enabled = _get_conf_boolean(self.spark, "spark.ssl.enabled", "false")
        ignore_ssl = _get_conf_boolean(self.spark, self.ssl_conf, "false")  # type: ignore
        if is_ssl_enabled:
            name = self.__class__.__name__
            if ignore_ssl:
                self.logger.warning(
                    textwrap.dedent(
                        f"""
                        This cluster has TLS encryption enabled;
                        however, {name} does not
                        support data encryption in transit.
                        The Spark configuration
                        '{self.ssl_conf}' has been set to
                        'true' to override this
                        configuration and use {name} anyway. Please
                        note this will cause model
                        parameters and possibly training data to
                        be sent between nodes unencrypted.
                        """,
                    )
                )
                return
            raise RuntimeError(
                textwrap.dedent(
                    f"""
                    This cluster has TLS encryption enabled;
                    however, {name} does not support
                    data encryption in transit. To override
                    this configuration and use {name}
                    anyway, you may set '{self.ssl_conf}'
                    to 'true' in the Spark configuration. Please note this
                    will cause model parameters and possibly training
                    data to be sent between nodes unencrypted.
                    """
                )
            )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class TorchDistributor(Distributor):
|
|
289
|
+
"""
|
|
290
|
+
A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
|
|
291
|
+
|
|
292
|
+
.. versionadded:: 3.4.0
|
|
293
|
+
|
|
294
|
+
.. versionchanged:: 3.5.0
|
|
295
|
+
Supports Spark Connect.
|
|
296
|
+
|
|
297
|
+
Parameters
|
|
298
|
+
----------
|
|
299
|
+
num_processes : int, optional
|
|
300
|
+
An integer that determines how many different concurrent
|
|
301
|
+
tasks are allowed. We expect spark.task.gpus = 1 for GPU-enabled training. Default
|
|
302
|
+
should be 1; we don't want to invoke multiple cores/gpus without explicit mention.
|
|
303
|
+
local_mode : bool, optional
|
|
304
|
+
A boolean that determines whether we are using the driver
|
|
305
|
+
node for training. Default should be false; we don't want to invoke executors without
|
|
306
|
+
explicit mention.
|
|
307
|
+
use_gpu : bool, optional
|
|
308
|
+
A boolean that indicates whether or not we are doing training
|
|
309
|
+
on the GPU. Note that there are differences in how GPU-enabled code looks like and
|
|
310
|
+
how CPU-specific code looks like.
|
|
311
|
+
|
|
312
|
+
Examples
|
|
313
|
+
--------
|
|
314
|
+
Run PyTorch Training locally on GPU (using a PyTorch native function)
|
|
315
|
+
|
|
316
|
+
>>> def train(learning_rate):
|
|
317
|
+
... import torch.distributed
|
|
318
|
+
... torch.distributed.init_process_group(backend="nccl")
|
|
319
|
+
... # ...
|
|
320
|
+
... torch.destroy_process_group()
|
|
321
|
+
... return model # or anything else
|
|
322
|
+
...
|
|
323
|
+
>>> distributor = TorchDistributor(
|
|
324
|
+
... num_processes=2,
|
|
325
|
+
... local_mode=True,
|
|
326
|
+
... use_gpu=True)
|
|
327
|
+
>>> model = distributor.run(train, 1e-3)
|
|
328
|
+
|
|
329
|
+
Run PyTorch Training on GPU (using a file with PyTorch code)
|
|
330
|
+
|
|
331
|
+
>>> distributor = TorchDistributor(
|
|
332
|
+
... num_processes=2,
|
|
333
|
+
... local_mode=False,
|
|
334
|
+
... use_gpu=True)
|
|
335
|
+
>>> distributor.run("/path/to/train.py", "--learning-rate=1e-3")
|
|
336
|
+
|
|
337
|
+
Run PyTorch Lightning Training on GPU
|
|
338
|
+
|
|
339
|
+
>>> num_proc = 2
|
|
340
|
+
>>> def train():
|
|
341
|
+
... from pytorch_lightning import Trainer
|
|
342
|
+
... # ...
|
|
343
|
+
... # required to set devices = 1 and num_nodes = num_processes for multi node
|
|
344
|
+
... # required to set devices = num_processes and num_nodes = 1 for single node multi GPU
|
|
345
|
+
... trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
|
|
346
|
+
... trainer.fit()
|
|
347
|
+
... # ...
|
|
348
|
+
... return trainer
|
|
349
|
+
...
|
|
350
|
+
>>> distributor = TorchDistributor(
|
|
351
|
+
... num_processes=num_proc,
|
|
352
|
+
... local_mode=True,
|
|
353
|
+
... use_gpu=True)
|
|
354
|
+
>>> trainer = distributor.run(train)
|
|
355
|
+
"""
|
|
356
|
+
|
|
357
|
+
_PICKLED_FUNC_FILE = "func.pickle"
|
|
358
|
+
_TRAIN_FILE = "train.py"
|
|
359
|
+
_PICKLED_OUTPUT_FILE = "output.pickle"
|
|
360
|
+
_TORCH_SSL_CONF = "pytorch.spark.distributor.ignoreSsl"
|
|
361
|
+
|
|
362
|
+
def __init__(
    self,
    num_processes: int = 1,
    local_mode: bool = True,
    use_gpu: bool = True,
    _ssl_conf: str = _TORCH_SSL_CONF,
):
    """Initializes the distributor.

    Parameters
    ----------
    num_processes : int, optional
        How many training processes to run concurrently. With GPU-enabled
        training, one GPU per Spark task (spark.task.gpus = 1) is expected.
        Defaults to 1, so multiple cores/GPUs are only used when asked for.
    local_mode : bool, optional
        Whether training runs on the driver node rather than on executors.
        Defaults to True.
    use_gpu : bool, optional
        Whether training runs on GPUs. GPU-enabled and CPU-only training code
        generally differ. Defaults to True.
    _ssl_conf : str, optional
        Internal: name of the Spark configuration that allows running on a
        TLS-enabled cluster.

    Raises
    ------
    ValueError
        If any of the parameters are incorrect.
    RuntimeError
        If an active SparkSession is unavailable.
    """
    super().__init__(
        num_processes=num_processes,
        local_mode=local_mode,
        use_gpu=use_gpu,
        ssl_conf=_ssl_conf,
    )
    self._validate_input_params()
    self.input_params = self._create_input_params()
|
|
396
|
+
|
|
397
|
+
@staticmethod
def _get_torchrun_args(local_mode: bool, num_processes: int) -> Tuple[List[Any], int]:
    """Build the torchrun arguments for the given mode and process count.

    Parameters
    ----------
    local_mode : bool
        True when every process runs on a single node; False for one process
        per node across ``num_processes`` nodes.
    num_processes : int
        Total number of processes to launch.

    Returns
    -------
    Tuple[List[Any], int]
        The torchrun arguments and the number of processes per node.

    Raises
    ------
    KeyError
        In distributed mode, if ``MASTER_ADDR``, ``MASTER_PORT`` or ``RANK``
        is missing from the environment.
    """
    if not local_mode:
        env = os.environ
        rendezvous_endpoint = f"{env['MASTER_ADDR']}:{env['MASTER_PORT']}"
        distributed_args = [
            f"--nnodes={num_processes}",
            f"--node_rank={env['RANK']}",
            f"--rdzv_endpoint={rendezvous_endpoint}",
            "--rdzv_id=0",  # TODO: setup random ID that is gleaned from env variables
        ]
        return distributed_args, 1
    return ["--standalone", "--nnodes=1"], num_processes
|
|
432
|
+
|
|
433
|
+
@staticmethod
def _create_torchrun_command(
    input_params: Dict[str, Any], path_to_train_file: str, *args: Any
) -> List[str]:
    """
    Assemble the argv that launches a training script through torchrun.

    The command runs ``pyspark.ml.torch.torch_run_process_wrapper`` with the current
    interpreter. It then lists the torchrun launcher arguments, the per-node process
    count, the training script path, and the script's own arguments.

    Parameters
    ----------
    input_params: Dict[str, Any]
        Must provide the ``"local_mode"`` and ``"num_processes"`` entries.
    path_to_train_file: str
        Path of the Python script to execute.
    *args: Any
        Extra script arguments; each one is converted with ``str``.

    Returns
    -------
    List[str]
        The argument list, ready for ``subprocess.Popen``.
    """
    launcher_args, nproc_per_node = TorchDistributor._get_torchrun_args(
        local_mode=input_params["local_mode"],
        num_processes=input_params["num_processes"],
    )
    command = [sys.executable, "-m", "pyspark.ml.torch.torch_run_process_wrapper"]
    command.extend(launcher_args)
    command.append(f"--nproc_per_node={nproc_per_node}")
    command.append(path_to_train_file)
    command.extend(str(arg) for arg in args)
    return command
|
|
454
|
+
|
|
455
|
+
@staticmethod
def _execute_command(
    cmd: List[str],
    _prctl: bool = True,
    redirect_to_stdout: bool = True,
    log_streaming_client: Optional["LogStreamingClient"] = None,
) -> None:
    """
    Run ``cmd`` as a subprocess, stream its combined stdout/stderr, and fail on error.

    Parameters
    ----------
    cmd: List[str]
        The argv of the command to run.
    _prctl: bool
        Accepted for interface compatibility; not used by this implementation.
    redirect_to_stdout: bool
        When True, each output line is echoed to this process's stdout. The echo is
        skipped when ``log_streaming_client`` is connected to a peer on the same host,
        because that server already prints the logs.
    log_streaming_client: Optional[LogStreamingClient]
        When given, each output line (with trailing whitespace stripped) is also
        sent through this client.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero return code. The message includes up to
        the last 100 lines of the command's output.
    """
    _TAIL_LINES_TO_KEEP = 100

    task = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        stdin=subprocess.PIPE,
        env=os.environ,
    )
    task.stdin.close()  # type: ignore
    tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
    try:
        for line in task.stdout:  # type: ignore
            # Undecodable bytes (binary output, non-UTF-8 locales) must not abort the
            # training run just because a log line cannot be decoded.
            decoded = line.decode(errors="replace")
            tail.append(decoded)
            if redirect_to_stdout:
                if (
                    log_streaming_client
                    and not log_streaming_client.failed
                    and (
                        log_streaming_client.sock.getsockname()[0]
                        == log_streaming_client.sock.getpeername()[0]
                    )
                ):
                    # If log_streaming_client and log_stream_server are in the same
                    # node (typical case is spark local mode),
                    # server side will redirect the log to STDOUT,
                    # to avoid STDOUT outputs duplication, skip redirecting
                    # logs to STDOUT in client side.
                    pass
                else:
                    sys.stdout.write(decoded)
            if log_streaming_client:
                log_streaming_client.send(decoded.rstrip())
        task.wait()
    finally:
        if task.poll() is None:
            try:
                task.terminate()  # SIGTERM
                time.sleep(0.5)
                if task.poll() is None:
                    task.kill()  # SIGKILL
            except OSError:
                pass
        if task.stdout is not None:
            # Release the pipe's file descriptor now rather than at garbage collection.
            task.stdout.close()
    if task.returncode != os.EX_OK:
        if len(tail) == _TAIL_LINES_TO_KEEP:
            last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
        else:
            last_n_msg = "task output is"
        task_output = "".join(tail)
        raise RuntimeError(
            f"Command {cmd} failed with return code {task.returncode}. "
            f"The {last_n_msg} included below: {task_output}"
        )
|
|
516
|
+
|
|
517
|
+
@staticmethod
def _get_output_from_framework_wrapper(
    framework_wrapper: Optional[Callable],
    input_params: Dict,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Call the framework wrapper with the argument layout that matches ``train_object``.

    Parameters
    ----------
    framework_wrapper: Optional[Callable]
        Function that will be invoked. When train_object is a string, it runs
        distributed training on a file. Otherwise, it runs distributed training for
        a function.
    input_params: Dict
        A dictionary that maps parameter to arguments for the command to be created.
    train_object: Union[Callable, str]
        This input comes from the user. A string (including ``str`` subclasses) is
        treated as a file path. Otherwise, the object is a function the user wants to
        run in a distributed manner.
    run_pytorch_file_fn: Optional[Callable]
        The function used to run distributed training of a file. It is only passed
        on when training a function. It defaults to
        ``TorchDistributor._run_training_on_pytorch_file``.
    *args: Any
        Extra arguments to be used by framework wrapper.
    **kwargs: Any
        Extra keyword args forwarded to the framework wrapper.

    Returns
    -------
    Optional[Any]
        Returns the result of the framework_wrapper

    Raises
    ------
    RuntimeError
        If ``framework_wrapper`` is not set.
    """
    if not framework_wrapper:
        raise RuntimeError(
            "`framework_wrapper` is not set. Cannot determine how to run `train_object`."
        )
    # The object to train is a file path, so framework_wrapper is some
    # run_training_on_pytorch_file function. ``isinstance`` matches the check used by
    # ``_run`` when it picks the wrapper, so both agree for ``str`` subclasses.
    if isinstance(train_object, str):
        return framework_wrapper(input_params, train_object, *args, **kwargs)
    # We are doing training with a function, will call run_training_on_pytorch_function
    if not run_pytorch_file_fn:
        run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file
    return framework_wrapper(
        input_params, train_object, run_pytorch_file_fn, *args, **kwargs
    )
|
|
569
|
+
|
|
570
|
+
def _run_local_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Run training in the current process's environment (local mode).

    In legacy (non-remote) GPU mode, ``CUDA_VISIBLE_DEVICES`` is temporarily set to a
    subset of the GPUs returned by ``_get_gpus_owned``. Its previous value (or
    absence) is restored afterwards, even if training raises.

    Parameters
    ----------
    framework_wrapper_fn : Callable
        Wrapper that runs either a training file or a training function.
    train_object : Union[Callable, str]
        The training function, or the path to a training script.
    run_pytorch_file_fn : Optional[Callable]
        Runner for training files, forwarded to the framework wrapper.
    *args, **kwargs :
        Forwarded to the framework wrapper.

    Returns
    -------
    Optional[Any]
        Whatever the framework wrapper returns.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
    old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
    try:
        # Only replace the GPUs with 'SparkContext.resources' in legacy mode.
        # In connect mode, this replacement is skipped since only GPUs on the client side
        # can be used.
        if self.use_gpu and not self.is_remote:
            gpus_owned = _get_gpus_owned(self.spark)
            # A private generator seeded from train_object selects the same GPUs as
            # seeding the module-level generator would. It leaves the caller's global
            # ``random`` state untouched.
            rng = random.Random(hash(train_object))
            selected_gpus = [str(e) for e in rng.sample(gpus_owned, self.num_processes)]
            os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)

        self.logger.info(f"Started local training with {self.num_processes} processes")
        output = TorchDistributor._get_output_from_framework_wrapper(
            framework_wrapper_fn,
            self.input_params,
            train_object,
            run_pytorch_file_fn,
            *args,
            **kwargs,
        )
        self.logger.info(f"Finished local training with {self.num_processes} processes")

    finally:
        if cuda_state_was_set:
            os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
        else:
            if CUDA_VISIBLE_DEVICES in os.environ:
                del os.environ[CUDA_VISIBLE_DEVICES]

    return output
|
|
610
|
+
|
|
611
|
+
def _get_spark_task_function(
    self,
    framework_wrapper_fn: Optional[Callable],
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    input_dataframe: Optional["DataFrame"],
    *args: Any,
    **kwargs: Any,
) -> Callable:
    """Creates the barrier Spark task function passed to `mapInArrow`.

    Each Spark task configures its GPU visibility and torch rendezvous environment
    variables, then runs the framework wrapper. It streams logs back to the driver
    through a `LogStreamingClient`. Only partition 0 yields the cloudpickled result,
    split into 4096-byte chunks in a single Arrow record batch with a `chunk` column.

    Parameters
    ----------
    framework_wrapper_fn : Optional[Callable]
        The function that determines whether we are running training
        on a PyTorch file or a PyTorch function.
    train_object : Union[Callable, str]
        The actual train function/file.
    run_pytorch_file_fn : Optional[Callable]
        Runner for training files, forwarded to the framework wrapper.
    input_dataframe : Optional[DataFrame]
        When given, its schema (as JSON) is used to spill each task's partition
        data to local files for the training function; None means no input data.
    *args, **kwargs :
        Forwarded to the framework wrapper.

    Returns
    -------
    Callable
        The wrapped function ready for use with `mapInArrow(..., barrier=True)`
    """
    # Copy the needed attributes into locals: the nested task function below only
    # references these names, not `self` (presumably so the pickled closure does not
    # capture the distributor and its Spark session).
    num_processes = self.num_processes
    use_gpu = self.use_gpu
    input_params = self.input_params
    driver_address = self.driver_address
    log_streaming_server_port = self.log_streaming_server_port
    is_spark_local_master = self.is_spark_local_master
    driver_owned_gpus: List[str] = []
    if is_spark_local_master and use_gpu:
        driver_owned_gpus = _get_gpus_owned(self.spark)

    if input_dataframe is not None:
        schema_json = input_dataframe.schema.jsonValue()
    else:
        schema_json = None

    # Spark task program
    def wrapped_train_fn(iterator):  # type: ignore[no-untyped-def]
        import os
        import pandas as pd
        import pyarrow
        from pyspark import BarrierTaskContext

        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"

        def get_free_port(address: str, context: "BarrierTaskContext") -> int:
            # Partition 0 asks the OS for an ephemeral port on `address` and shares
            # it with every task via allGather. Other tasks contribute "".
            # NOTE(review): `[0]` assumes allGather results are ordered by partition
            # ID. The probe socket is not closed explicitly, and the port may be taken
            # by another process before torchrun binds it.
            port = ""
            if context.partitionId() == 0:
                try:
                    import socket

                    sock = socket.socket()
                    sock.bind((address, 0))
                    port = sock.getsockname()[1]
                except socket.error:
                    pass
            available_port = context.allGather(str(port))[0]
            if not available_port:
                raise RuntimeError("Failed to find free port for distributed training.")
            return int(available_port)

        def set_torch_config(context: "BarrierTaskContext") -> None:
            # The host of the first barrier task acts as the rendezvous master; the
            # partition ID doubles as node rank and global rank.
            addrs = [e.address.split(":")[0] for e in context.getTaskInfos()]

            os.environ["MASTER_ADDR"] = str(addrs[0])
            os.environ["MASTER_PORT"] = str(get_free_port(addrs[0], context))
            os.environ["WORLD_SIZE"] = str(num_processes)
            os.environ["NODE_RANK"] = str(context.partitionId())
            os.environ["RANK"] = str(context.partitionId())

            # More partitions than processes would produce ranks outside WORLD_SIZE.
            if context.partitionId() >= num_processes:
                raise ValueError(
                    "TorchDistributor._train_on_dataframe requires setting num_processes "
                    "equal to input spark dataframe partition number."
                )

        if is_spark_local_master:
            # distributed training on a local mode spark cluster
            def set_gpus(context: "BarrierTaskContext") -> None:
                # Respect an already configured device list.
                if CUDA_VISIBLE_DEVICES in os.environ:
                    return

                # One driver-owned GPU per task, chosen by partition index.
                gpu_owned = driver_owned_gpus[context.partitionId()]
                os.environ[CUDA_VISIBLE_DEVICES] = gpu_owned

        else:

            def set_gpus(context: "BarrierTaskContext") -> None:
                # Respect an already configured device list.
                if CUDA_VISIBLE_DEVICES in os.environ:
                    return

                # Expose every GPU assigned to this task.
                gpus_owned = _get_gpus_owned(context)
                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(gpus_owned)

        context = BarrierTaskContext.get()

        if use_gpu:
            set_gpus(context)
        else:
            # CPU training: hide all GPUs from the training processes.
            os.environ[CUDA_VISIBLE_DEVICES] = ""
        set_torch_config(context)

        # The client is handed to the file runner through input_params
        # (`_run_training_on_pytorch_file` reads the "log_streaming_client" key).
        log_streaming_client = LogStreamingClient(driver_address, log_streaming_server_port)
        input_params["log_streaming_client"] = log_streaming_client
        try:
            with TorchDistributor._setup_spark_partition_data(iterator, schema_json):
                output = TorchDistributor._get_output_from_framework_wrapper(
                    framework_wrapper_fn,
                    input_params,
                    train_object,
                    run_pytorch_file_fn,
                    *args,
                    **kwargs,
                )
        finally:
            # Best-effort teardown of the log client; errors here are deliberately ignored.
            try:
                LogStreamingClient._destroy()
            except BaseException:
                pass

        # Only partition 0 returns the result; other tasks yield no rows.
        if context.partitionId() == 0:
            output_bytes = cloudpickle.dumps(output)
            output_size = len(output_bytes)

            # In Spark Connect, DataFrame.collect stacks rows to size
            # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
            # here use 4KiB for each chunk, which mean each arrow batch
            # may contain about 1000 chunks.
            chunks = []
            chunk_size = 4096
            index = 0
            while index < output_size:
                chunks.append(output_bytes[index : index + chunk_size])
                index += chunk_size

            yield pyarrow.RecordBatch.from_pandas(pd.DataFrame(data={"chunk": chunks}))

    return wrapped_train_fn
|
|
752
|
+
|
|
753
|
+
def _run_distributed_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    spark_dataframe: Optional["DataFrame"],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Run training as a barrier Spark job and return the unpickled result of partition 0.

    A log streaming server is started on the driver host first. If it fails to start,
    training continues with a warning. The job then runs over ``spark_dataframe``, or
    over ``self.num_tasks`` single-row partitions when no DataFrame is given. The
    cloudpickled chunks collected from the job are joined and unpickled.

    Parameters
    ----------
    framework_wrapper_fn : Callable
        Wrapper that runs either a training file or a training function.
    train_object : Union[Callable, str]
        The training function, or the path to a training script.
    run_pytorch_file_fn : Optional[Callable]
        Runner for training files, forwarded to the framework wrapper.
    spark_dataframe : Optional[DataFrame]
        Optional input data; each barrier task processes one of its partitions.
    *args, **kwargs :
        Forwarded to the framework wrapper.

    Returns
    -------
    Optional[Any]
        The object produced by the task with partition ID 0.

    Raises
    ------
    RuntimeError
        If ``framework_wrapper_fn`` is not set.
    AssertionError
        If ``spark.driver.host`` is not configured.
    """
    if not framework_wrapper_fn:
        raise RuntimeError("Unknown combination of parameters")

    log_streaming_server = LogStreamingServer()
    self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
    if self.driver_address == "":
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise AssertionError("`spark.driver.host` must be set for distributed training.")
    try:
        log_streaming_server.start(spark_host_address=self.driver_address)
        time.sleep(1)  # wait for the server to start
        self.log_streaming_server_port = log_streaming_server.port
    except Exception as e:
        # If starting log streaming server failed, we don't need to break
        # the distributor training but emit a warning instead.
        self.log_streaming_server_port = -1
        # Pass one message string. With stdlib logging, extra positional arguments
        # are %-format arguments. This message has no placeholders, so passing the
        # pieces separately made the log call itself fail.
        self.logger.warning(
            "Start torch distributor log streaming server failed, "
            "You cannot receive logs sent from distributor workers, "
            f"error: {repr(e)}."
        )

    try:
        spark_task_function = self._get_spark_task_function(
            framework_wrapper_fn,
            train_object,
            run_pytorch_file_fn,
            spark_dataframe,
            *args,
            **kwargs,
        )
        self._check_encryption()
        self.logger.info(
            f"Started distributed training with {self.num_processes} executor processes"
        )
        if spark_dataframe is not None:
            input_df = spark_dataframe
        else:
            input_df = self.spark.range(
                start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
            )
        rows = input_df.mapInArrow(
            func=spark_task_function, schema="chunk binary", barrier=True
        ).collect()
        # Only partition 0 emits rows, so joining every chunk rebuilds its pickle.
        output_bytes = b"".join([row.chunk for row in rows])
        result = cloudpickle.loads(output_bytes)
    finally:
        log_streaming_server.shutdown()
    self.logger.info(
        f"Finished distributed training with {self.num_processes} executor processes"
    )
    return result
|
|
812
|
+
|
|
813
|
+
@staticmethod
def _run_training_on_pytorch_file(
    input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
) -> None:
    """
    Launch the script at ``train_path`` through torchrun and wait for it to finish.

    Positional ``args`` become command-line arguments of the script. If
    ``input_params`` holds a ``"log_streaming_client"`` entry, it receives the
    script's output.

    Raises
    ------
    ValueError
        If keyword arguments are supplied; scripts only take positional arguments.
    """
    if kwargs:
        raise ValueError("Running pytorch file does not support key-word type arguments.")
    streaming_client = input_params.get("log_streaming_client", None)
    TorchDistributor._execute_command(
        TorchDistributor._create_torchrun_command(input_params, train_path, *args),
        log_streaming_client=streaming_client,
    )
|
|
826
|
+
|
|
827
|
+
@staticmethod
@contextmanager
def _setup_files(
    train_fn: Callable, *args: Any, **kwargs: Any
) -> Generator[Tuple[str, str], None, None]:
    """
    Prepare a temporary directory holding everything torchrun needs to run ``train_fn``.

    The directory contains the cloudpickled ``(train_fn, args, kwargs)`` and a generated
    training script. The context manager yields ``(train_file_path, output_file_path)``.
    The directory is always removed on exit, including when setup itself fails.
    """
    save_dir = TorchDistributor._create_save_dir()
    # Everything after the directory exists sits inside ``try``. This way the
    # directory is also cleaned up when pickling ``train_fn`` or writing the script
    # raises, not only after the ``with`` body.
    try:
        pickle_file_path = TorchDistributor._save_pickled_function(
            save_dir, train_fn, *args, **kwargs
        )
        output_file_path = os.path.join(save_dir, TorchDistributor._PICKLED_OUTPUT_FILE)
        train_file_path = TorchDistributor._create_torchrun_train_file(
            save_dir, pickle_file_path, output_file_path
        )
        yield (train_file_path, output_file_path)
    finally:
        TorchDistributor._cleanup_files(save_dir)
|
|
844
|
+
|
|
845
|
+
@staticmethod
@contextmanager
def _setup_spark_partition_data(
    partition_data_iterator: Iterator[Any], input_schema_json: Optional[Dict[str, Any]]
) -> Iterator[Any]:
    """
    Spill the current Spark partition to local files for the duration of the context.

    When ``input_schema_json`` is None, nothing is written and the context is a no-op.
    Otherwise, the partition is written as an Arrow stream and the schema as JSON.
    Both files go into a fresh directory under ``SparkFiles.getRootDirectory()``.
    Their paths are exported through the ``SPARK_PARTITION_ARROW_DATA_FILE`` and
    ``SPARK_DATAFRAME_SCHEMA_FILE`` environment variables. Variables and files are
    removed on exit.

    Raises
    ------
    ValueError
        If the partition is empty.
    """
    from pyspark.sql.pandas.serializers import ArrowStreamSerializer
    from pyspark.files import SparkFiles
    import json

    if input_schema_json is None:
        yield
        return

    # We need to temporarily write partition data into a temp dir,
    # partition data might be huge, so we need to write it under
    # configured `SPARK_LOCAL_DIRS`.
    save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())

    try:
        serializer = ArrowStreamSerializer()
        arrow_file_path = os.path.join(save_dir, "data.arrow")
        with open(arrow_file_path, "wb") as f:
            serializer.dump_stream(partition_data_iterator, f)
            if f.tell() == 0:
                # Nothing is written to file, this partition is empty
                raise ValueError(
                    "Empty Spark partition is not allowed in "
                    "TorchDistributor.train_on_dataframe."
                )

        schema_file_path = os.path.join(save_dir, "schema.json")
        schema_json_string = json.dumps(input_schema_json)

        with open(schema_file_path, "w") as f:
            f.write(schema_json_string)

        os.environ[SPARK_PARTITION_ARROW_DATA_FILE] = arrow_file_path
        os.environ[SPARK_DATAFRAME_SCHEMA_FILE] = schema_file_path
        yield
    finally:
        # The variables are unset when a failure happens before they were exported
        # (e.g. the empty-partition error above). The default keeps a KeyError from
        # masking that original exception.
        os.environ.pop(SPARK_PARTITION_ARROW_DATA_FILE, None)
        os.environ.pop(SPARK_DATAFRAME_SCHEMA_FILE, None)
        TorchDistributor._cleanup_files(save_dir)
|
|
888
|
+
|
|
889
|
+
@staticmethod
def _run_training_on_pytorch_function(
    input_params: Dict[str, Any],
    train_fn: Callable,
    run_pytorch_file_fn: Optional[Callable],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Run ``train_fn(*args, **kwargs)`` through a generated torchrun script and return its result.

    Parameters
    ----------
    input_params : Dict[str, Any]
        Parameters forwarded to the file runner.
    train_fn : Callable
        The user's training function; it is cloudpickled together with its arguments.
    run_pytorch_file_fn : Optional[Callable]
        Runner used to execute the generated script. It defaults to
        ``TorchDistributor._run_training_on_pytorch_file``.
    *args, **kwargs :
        Arguments for ``train_fn``.

    Returns
    -------
    Any
        The object the generated script cloudpickled to its output file.

    Raises
    ------
    RuntimeError
        If the script produced no output file, or the output could not be unpickled.
    """
    if not run_pytorch_file_fn:
        run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file

    with TorchDistributor._setup_files(train_fn, *args, **kwargs) as (
        train_file_path,
        output_file_path,
    ):
        run_pytorch_file_fn(input_params, train_file_path)
        # The generated script writes the output file only after train_fn returns.
        if not os.path.exists(output_file_path):
            raise RuntimeError(
                "TorchDistributor failed during training. "
                "View stdout logs for detailed error message."
            )
        try:
            output = TorchDistributor._get_pickled_output(output_file_path)
        except Exception as e:
            raise RuntimeError(
                "TorchDistributor failed due to a pickling error. "
                "View stdout logs for detailed error message."
            ) from e
    return output
|
|
919
|
+
|
|
920
|
+
@staticmethod
def _create_save_dir(root_dir: Optional[str] = None) -> str:
    """
    Create a new, uniquely named temporary directory and return its path.

    ``tempfile.mkdtemp`` chooses a unique name, so separate calls never share a
    directory. When ``root_dir`` is None, the platform's default temporary location
    is used.
    """
    new_dir = tempfile.mkdtemp(dir=root_dir)
    return new_dir
|
|
924
|
+
|
|
925
|
+
@staticmethod
def _cleanup_files(save_dir: str) -> None:
    """Recursively delete ``save_dir``, silently ignoring errors such as a missing path."""
    shutil.rmtree(path=save_dir, ignore_errors=True)
|
|
928
|
+
|
|
929
|
+
@staticmethod
def _save_pickled_function(
    save_dir: str, train_fn: Union[str, Callable], *args: Any, **kwargs: Any
) -> str:
    """
    Cloudpickle the tuple ``(train_fn, args, kwargs)`` into ``save_dir``.

    Returns the path of the written file, named ``TorchDistributor._PICKLED_FUNC_FILE``.
    """
    target_path = os.path.join(save_dir, TorchDistributor._PICKLED_FUNC_FILE)
    payload = (train_fn, args, kwargs)
    with open(target_path, "wb") as sink:
        cloudpickle.dump(payload, sink)
    return target_path
|
|
937
|
+
|
|
938
|
+
@staticmethod
def _create_torchrun_train_file(
    save_dir_path: str, pickle_file_path: str, output_file_path: str
) -> str:
    """
    Write the Python script that torchrun executes to run a pickled training function.

    The generated script loads ``(train_fn, args, kwargs)`` from ``pickle_file_path``,
    calls ``train_fn(*args, **kwargs)``, and cloudpickles the result to
    ``output_file_path``.

    Parameters
    ----------
    save_dir_path : str
        Directory the script is written into, as ``TorchDistributor._TRAIN_FILE``.
    pickle_file_path : str
        Path of the pickled function and arguments.
    output_file_path : str
        Path where the script stores the function's result.

    Returns
    -------
    str
        Path of the written script.
    """
    # Paths are embedded with ``!r`` so they become valid Python string literals
    # even when they contain backslashes (Windows paths) or quote characters.
    code = textwrap.dedent(
        f"""
        from pyspark import cloudpickle
        import os

        if __name__ == "__main__":
            with open({pickle_file_path!r}, "rb") as f:
                train_fn, args, kwargs = cloudpickle.load(f)
            output = train_fn(*args, **kwargs)
            with open({output_file_path!r}, "wb") as f:
                cloudpickle.dump(output, f)
        """
    )
    saved_file_path = os.path.join(save_dir_path, TorchDistributor._TRAIN_FILE)
    with open(saved_file_path, "w") as f:
        f.write(code)
    return saved_file_path
|
|
959
|
+
|
|
960
|
+
@staticmethod
def _get_pickled_output(output_file_path: str) -> Any:
    """Load and return the object cloudpickled into ``output_file_path``."""
    with open(output_file_path, "rb") as source:
        return cloudpickle.load(source)
|
|
965
|
+
|
|
966
|
+
def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
    """Runs distributed training.

    Parameters
    ----------
    train_object : callable object or str
        Either a PyTorch function, PyTorch Lightning function, or the path to a python file
        that launches distributed training.
    args :
        If train_object is a python function and not a path to a python file, args need
        to be the input parameters to that function. It would look like

        >>> model = distributor.run(train, 1e-3, 64)

        where train is a function and 1e-3 and 64 are regular numeric inputs to the function.

        If train_object is a python file, then args would be the command-line arguments for
        that python file which are all in the form of strings. An example would be

        >>> distributor.run("/path/to/train.py", "--learning-rate=1e-3", "--batch-size=64")

        where since the input is a path, all of the parameters are strings that can be
        handled by argparse in that python file.
    kwargs :
        If train_object is a python function and not a path to a python file, kwargs need
        to be the key-word input parameters to that function. It would look like

        >>> model = distributor.run(train, tol=1e-3, max_iter=64)

        where train is a function of 2 arguments `tol` and `max_iter`.

        If train_object is a python file, then you should not set kwargs arguments.

    Returns
    -------
        Returns the output of train_object called with args inside spark rank 0 task if the
        train_object is a Callable with an expected output. Returns None if train_object is
        a file.
    """
    file_runner = TorchDistributor._run_training_on_pytorch_file
    return self._run(train_object, file_runner, *args, **kwargs)
|
|
1008
|
+
|
|
1009
|
+
def _run(
    self,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """
    Pick the framework wrapper for ``train_object``, then train locally or distributed.

    A string ``train_object`` is run directly by ``run_pytorch_file_fn``. Anything else
    goes through ``TorchDistributor._run_training_on_pytorch_function``. When
    ``self.local_mode`` is set, training runs locally; otherwise it runs as a
    distributed Spark job with no input DataFrame.
    """
    framework_wrapper_fn = (
        run_pytorch_file_fn
        if isinstance(train_object, str)
        else TorchDistributor._run_training_on_pytorch_function
    )
    if not self.local_mode:
        return self._run_distributed_training(
            framework_wrapper_fn, train_object, run_pytorch_file_fn, None, *args, **kwargs
        )
    return self._run_local_training(
        framework_wrapper_fn, train_object, run_pytorch_file_fn, *args, **kwargs
    )
|
|
1029
|
+
|
|
1030
|
+
def _train_on_dataframe(
    self,
    train_function: Callable,
    spark_dataframe: "DataFrame",
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Runs distributed training using the provided Spark DataFrame as input data.

    A barrier Spark job is started in which every task processes one partition of
    ``spark_dataframe``. The DataFrame should therefore have evenly sized partitions,
    and their number must match ``num_processes``.

    Parameters
    ----------
    train_function :
        A PyTorch or PyTorch Lightning function that launches distributed training.
        Inside the function, call
        `pyspark.ml.torch.distributor.get_spark_partition_data_loader` to obtain a torch
        data loader over the corresponding partition of the input Spark DataFrame.
    spark_dataframe :
        The input Spark DataFrame consumed by `train_function`; see above.
    args :
        Positional input parameters of `train_function`, e.g.

        >>> model = distributor._train_on_dataframe(train, df, 1e-3, 64)

        where 1e-3 and 64 are regular numeric inputs to `train`.
    kwargs :
        Keyword input parameters of `train_function`, e.g.

        >>> model = distributor._train_on_dataframe(train, df, tol=1e-3, max_iter=64)

        where `train` accepts the arguments `tol` and `max_iter`.

    Returns
    -------
        The output of `train_function` called with args inside the Spark rank 0 task.

    Raises
    ------
    ValueError
        If the distributor is in local mode.
    """
    if not self.local_mode:
        return self._run_distributed_training(
            TorchDistributor._run_training_on_pytorch_function,
            train_function,
            TorchDistributor._run_training_on_pytorch_file,
            spark_dataframe,
            *args,
            **kwargs,
        )
    raise ValueError(
        "TorchDistributor.train_on_dataframe requires setting "
        "TorchDistributor.local_mode to False."
    )
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def _get_spark_partition_data_loader(
    num_samples: int, batch_size: int, num_workers: int = 1, prefetch_factor: int = 2
) -> Any:
    """
    Build a torch ``DataLoader`` over the Spark partition assigned to the current task.

    Only valid inside the `train_function` passed to
    `TorchDistributor.train_on_dataframe`. It reads the partition data and schema
    files whose paths are published in the ``SPARK_PARTITION_ARROW_DATA_FILE`` and
    ``SPARK_DATAFRAME_SCHEMA_FILE`` environment variables.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. With fewer rows than `num_samples` in
        the partition, iteration wraps around to the first row once every row has been
        loaded. With more rows, only the first `num_samples` rows are produced.
    batch_size:
        How many samples per batch to load.
    num_workers:
        How many subprocesses to use for data loading.
        0 means that the data will be loaded in the main process.
    prefetch_factor:
        Number of batches loaded in advance by each worker. Ignored when
        `num_workers` is 0.
    """
    from pyspark.sql.types import StructType
    from pyspark.ml.torch.data import _SparkPartitionTorchDataset
    from torch.utils.data import DataLoader

    arrow_file = os.environ[SPARK_PARTITION_ARROW_DATA_FILE]
    schema_file = os.environ[SPARK_DATAFRAME_SCHEMA_FILE]

    with open(schema_file, "r") as fp:
        schema = StructType.fromJson(json.load(fp))

    dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)

    loader_options: dict = {"num_workers": num_workers}
    if num_workers > 0:
        # torch rejects `prefetch_factor` when data is loaded in the main process,
        # so it is only passed when worker subprocesses are used.
        loader_options["prefetch_factor"] = prefetch_factor
    return DataLoader(dataset, batch_size, **loader_options)
|