snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/ml/fileset/sfcfs.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import collections
|
2
2
|
import logging
|
3
3
|
from functools import partial
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Optional, Union, cast
|
5
5
|
|
6
6
|
import fsspec
|
7
7
|
|
@@ -100,7 +100,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
100
100
|
raise ValueError("Either sf_connection or snowpark_session has to be non-empty!")
|
101
101
|
self._conn = self._session._conn._conn # Telemetry wrappers expect connection under `conn_attr_name="_conn"``
|
102
102
|
self._kwargs = kwargs
|
103
|
-
self._stage_fs_set:
|
103
|
+
self._stage_fs_set: dict[tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
|
104
104
|
|
105
105
|
super().__init__(**kwargs)
|
106
106
|
|
@@ -133,7 +133,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
133
133
|
assert isinstance(session, snowpark.Session)
|
134
134
|
return session
|
135
135
|
|
136
|
-
def __reduce__(self) ->
|
136
|
+
def __reduce__(self) -> tuple[Callable[[], type["SFFileSystem"]], tuple[()], dict[str, Any]]:
|
137
137
|
"""Returns a state dictionary for use in serialization.
|
138
138
|
|
139
139
|
Returns:
|
@@ -145,7 +145,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
145
145
|
|
146
146
|
return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary
|
147
147
|
|
148
|
-
def __setstate__(self, state_dict:
|
148
|
+
def __setstate__(self, state_dict: dict[str, Any]) -> None:
|
149
149
|
"""Sets the dictionary state at deserialization time, and rebuilds a snowflake connection.
|
150
150
|
|
151
151
|
Args:
|
@@ -191,7 +191,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
191
191
|
func_params_to_log=["detail"],
|
192
192
|
conn_attr_name="_conn",
|
193
193
|
)
|
194
|
-
def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[
|
194
|
+
def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[list[str], list[dict[str, Any]]]:
|
195
195
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
196
196
|
|
197
197
|
Args:
|
@@ -214,14 +214,14 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
214
214
|
file_path = self._parse_file_path(path)
|
215
215
|
stage_fs = self._get_stage_fs(file_path)
|
216
216
|
stage_path_list = stage_fs.ls(file_path.filepath, detail=True, **kwargs)
|
217
|
-
stage_path_list = cast(
|
217
|
+
stage_path_list = cast(list[dict[str, Any]], stage_path_list)
|
218
218
|
return self._decorate_ls_res(stage_fs, stage_path_list, detail)
|
219
219
|
|
220
220
|
@telemetry.send_api_usage_telemetry(
|
221
221
|
project=_PROJECT,
|
222
222
|
conn_attr_name="_conn",
|
223
223
|
)
|
224
|
-
def optimize_read(self, files: Optional[
|
224
|
+
def optimize_read(self, files: Optional[list[str]] = None) -> None:
|
225
225
|
"""Prefetch and cache the presigned urls for all the given files to speed up the file opening.
|
226
226
|
|
227
227
|
All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
|
@@ -232,8 +232,8 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
232
232
|
"""
|
233
233
|
if not files:
|
234
234
|
return
|
235
|
-
stage_fs_dict:
|
236
|
-
stage_file_paths:
|
235
|
+
stage_fs_dict: dict[str, stage_fs.SFStageFileSystem] = {}
|
236
|
+
stage_file_paths: dict[str, list[str]] = collections.defaultdict(list)
|
237
237
|
for file in files:
|
238
238
|
path_info = self._parse_file_path(file)
|
239
239
|
fs = self._get_stage_fs(path_info)
|
@@ -271,11 +271,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
271
271
|
project=_PROJECT,
|
272
272
|
conn_attr_name="_conn",
|
273
273
|
)
|
274
|
-
def info(self, path: str, **kwargs: Any) ->
|
274
|
+
def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
|
275
275
|
"""Override fsspec `info` method. Give details of entry at path."""
|
276
276
|
file_path = self._parse_file_path(path)
|
277
277
|
stage_fs = self._get_stage_fs(file_path)
|
278
|
-
res:
|
278
|
+
res: dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
|
279
279
|
if res:
|
280
280
|
res["name"] = self._stage_path_to_absolute_path(stage_fs, res["name"])
|
281
281
|
return res
|
@@ -283,9 +283,9 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
283
283
|
def _decorate_ls_res(
|
284
284
|
self,
|
285
285
|
stage_fs: stage_fs.SFStageFileSystem,
|
286
|
-
stage_path_list:
|
286
|
+
stage_path_list: list[dict[str, Any]],
|
287
287
|
detail: bool,
|
288
|
-
) -> Union[
|
288
|
+
) -> Union[list[str], list[dict[str, Any]]]:
|
289
289
|
"""Add the stage location as the prefix of file names returned by ls() of stagefs"""
|
290
290
|
for path in stage_path_list:
|
291
291
|
path["name"] = self._stage_path_to_absolute_path(stage_fs, path["name"])
|
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import logging
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, Optional, Union, cast
|
6
6
|
|
7
7
|
import fsspec
|
8
8
|
from fsspec.implementations import http as httpfs
|
@@ -44,7 +44,7 @@ class _PresignedUrl:
|
|
44
44
|
return not self.expire_at or time.time() > self.expire_at - headroom_sec
|
45
45
|
|
46
46
|
|
47
|
-
def _get_httpfs_kwargs(**kwargs: Any) ->
|
47
|
+
def _get_httpfs_kwargs(**kwargs: Any) -> dict[str, Any]:
|
48
48
|
"""Extract kwargs that are meaningful to HTTPFileSystem."""
|
49
49
|
httpfs_related_keys = [
|
50
50
|
"block_size",
|
@@ -124,7 +124,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
124
124
|
self._db = db
|
125
125
|
self._schema = schema
|
126
126
|
self._stage = stage
|
127
|
-
self._url_cache:
|
127
|
+
self._url_cache: dict[str, _PresignedUrl] = {}
|
128
128
|
|
129
129
|
httpfs_kwargs = _get_httpfs_kwargs(**kwargs)
|
130
130
|
self._fs = httpfs.HTTPFileSystem(**httpfs_kwargs)
|
@@ -145,7 +145,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
145
145
|
project=_PROJECT,
|
146
146
|
func_params_to_log=["detail"],
|
147
147
|
)
|
148
|
-
def ls(self, path: str, detail: bool = False) -> Union[
|
148
|
+
def ls(self, path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
|
149
149
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
150
150
|
|
151
151
|
Args:
|
@@ -169,7 +169,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
169
169
|
loc = self.stage_name
|
170
170
|
path = path.lstrip("/")
|
171
171
|
async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
|
172
|
-
objects:
|
172
|
+
objects: list[snowpark.Row] = _resolve_async_job(async_job)
|
173
173
|
except snowpark_exceptions.SnowparkSQLException as e:
|
174
174
|
if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
|
175
175
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -192,7 +192,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
192
192
|
@telemetry.send_api_usage_telemetry(
|
193
193
|
project=_PROJECT,
|
194
194
|
)
|
195
|
-
def optimize_read(self, files: Optional[
|
195
|
+
def optimize_read(self, files: Optional[list[str]] = None) -> None:
|
196
196
|
"""Prefetch and cache the presigned urls for all the given files to speed up the read performance.
|
197
197
|
|
198
198
|
All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
|
@@ -271,7 +271,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
271
271
|
original_exception=fileset_errors.StageFileNotFoundError(f"Stage file {path} doesn't exist."),
|
272
272
|
)
|
273
273
|
|
274
|
-
def _open_with_snowpark(self, path: str, **kwargs:
|
274
|
+
def _open_with_snowpark(self, path: str, **kwargs: dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
|
275
275
|
"""Open the a file for reading using snowflake.snowpark.file_operation
|
276
276
|
|
277
277
|
Args:
|
@@ -299,7 +299,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
299
299
|
original_exception=e,
|
300
300
|
)
|
301
301
|
|
302
|
-
def _parse_list_result(self, list_result:
|
302
|
+
def _parse_list_result(self, list_result: list[snowpark.Row], search_path: str) -> list[dict[str, Any]]:
|
303
303
|
"""Convert the result from LIST query to the expected format of fsspec ls() method.
|
304
304
|
|
305
305
|
Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
|
@@ -318,7 +318,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
318
318
|
Returns:
|
319
319
|
A list of dict, where each dict contains key-value pairs as the properties of a file.
|
320
320
|
"""
|
321
|
-
files:
|
321
|
+
files: dict[str, dict[str, Any]] = {}
|
322
322
|
search_path = search_path.strip("/")
|
323
323
|
for row in list_result:
|
324
324
|
name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
|
@@ -360,7 +360,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
360
360
|
|
361
361
|
def _add_file_info_helper(
|
362
362
|
self,
|
363
|
-
files:
|
363
|
+
files: dict[str, dict[str, Any]],
|
364
364
|
object_path: str,
|
365
365
|
file_size: int,
|
366
366
|
file_type: str,
|
@@ -379,12 +379,12 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
379
379
|
)
|
380
380
|
|
381
381
|
def _fetch_presigned_urls(
|
382
|
-
self, files:
|
383
|
-
) ->
|
382
|
+
self, files: list[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
|
383
|
+
) -> list[tuple[str, str]]:
|
384
384
|
"""Fetch presigned urls for the given files."""
|
385
385
|
file_df = self._session.create_dataframe(files).to_df("name")
|
386
386
|
try:
|
387
|
-
presigned_urls:
|
387
|
+
presigned_urls: list[tuple[str, str]] = file_df.select_expr(
|
388
388
|
f"name, get_presigned_url('{self.stage_name}', name, {url_lifetime}) as url"
|
389
389
|
).collect(
|
390
390
|
statement_params=telemetry.get_function_usage_statement_params(
|
@@ -418,10 +418,10 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
|
|
418
418
|
|
419
419
|
|
420
420
|
@snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
|
421
|
-
def _resolve_async_job(async_job: snowpark.AsyncJob) ->
|
421
|
+
def _resolve_async_job(async_job: snowpark.AsyncJob) -> list[snowpark.Row]:
|
422
422
|
# Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
|
423
423
|
try:
|
424
|
-
query_result = cast(
|
424
|
+
query_result = cast(list[snowpark.Row], async_job.result("row"))
|
425
425
|
return query_result
|
426
426
|
except snowpark_errors.DatabaseError as e:
|
427
427
|
# HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
|
@@ -13,7 +13,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
|
13
13
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
14
14
|
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
|
15
15
|
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
|
16
|
-
DEFAULT_IMAGE_TAG = "1.
|
16
|
+
DEFAULT_IMAGE_TAG = "1.2.3"
|
17
17
|
DEFAULT_ENTRYPOINT_PATH = "func.py"
|
18
18
|
|
19
19
|
# Percent of container memory to allocate for /dev/shm volume
|
@@ -10,7 +10,7 @@ import traceback
|
|
10
10
|
from collections import namedtuple
|
11
11
|
from dataclasses import dataclass
|
12
12
|
from types import TracebackType
|
13
|
-
from typing import Any, Callable,
|
13
|
+
from typing import Any, Callable, Optional, Union, cast
|
14
14
|
|
15
15
|
from snowflake import snowpark
|
16
16
|
from snowflake.snowpark import exceptions as sp_exceptions
|
@@ -33,7 +33,7 @@ class ExecutionResult:
|
|
33
33
|
def success(self) -> bool:
|
34
34
|
return self.exception is None
|
35
35
|
|
36
|
-
def to_dict(self) ->
|
36
|
+
def to_dict(self) -> dict[str, Any]:
|
37
37
|
"""Return the serializable dictionary."""
|
38
38
|
if isinstance(self.exception, BaseException):
|
39
39
|
exc_type = type(self.exception)
|
@@ -50,7 +50,7 @@ class ExecutionResult:
|
|
50
50
|
}
|
51
51
|
|
52
52
|
@classmethod
|
53
|
-
def from_dict(cls, result_dict:
|
53
|
+
def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
|
54
54
|
if not isinstance(result_dict.get("success"), bool):
|
55
55
|
raise ValueError("Invalid result dictionary")
|
56
56
|
|
@@ -242,11 +242,11 @@ def _install_sys_excepthook() -> None:
|
|
242
242
|
original_excepthook = sys.excepthook
|
243
243
|
|
244
244
|
def custom_excepthook(
|
245
|
-
exc_type:
|
245
|
+
exc_type: type[BaseException],
|
246
246
|
exc_value: BaseException,
|
247
247
|
exc_tb: Optional[TracebackType],
|
248
248
|
*,
|
249
|
-
seen_exc_ids: Optional[
|
249
|
+
seen_exc_ids: Optional[set[int]] = None,
|
250
250
|
) -> None:
|
251
251
|
if seen_exc_ids is None:
|
252
252
|
seen_exc_ids = set()
|
@@ -331,7 +331,7 @@ def _install_ipython_hook() -> bool:
|
|
331
331
|
except ImportError:
|
332
332
|
return False
|
333
333
|
|
334
|
-
def parse_traceback_str(traceback_str: str) ->
|
334
|
+
def parse_traceback_str(traceback_str: str) -> list[tuple[str, int, str, str]]:
|
335
335
|
return [
|
336
336
|
(m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
|
337
337
|
for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
|
@@ -342,13 +342,13 @@ def _install_ipython_hook() -> bool:
|
|
342
342
|
|
343
343
|
def custom_format_exception_as_a_whole(
|
344
344
|
self: VerboseTB,
|
345
|
-
etype:
|
345
|
+
etype: type[BaseException],
|
346
346
|
evalue: Optional[BaseException],
|
347
347
|
etb: Optional[TracebackType],
|
348
348
|
number_of_lines_of_context: int,
|
349
349
|
tb_offset: Optional[int],
|
350
350
|
**kwargs: Any,
|
351
|
-
) ->
|
351
|
+
) -> list[list[str]]:
|
352
352
|
if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
|
353
353
|
# Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
|
354
354
|
head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
|
@@ -388,7 +388,7 @@ def _install_ipython_hook() -> bool:
|
|
388
388
|
etb: Optional[TracebackType],
|
389
389
|
tb_offset: Optional[int] = None,
|
390
390
|
**kwargs: Any,
|
391
|
-
) ->
|
391
|
+
) -> list[str]:
|
392
392
|
if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
|
393
393
|
tb_list = [
|
394
394
|
(m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
|
@@ -400,7 +400,7 @@ def _install_ipython_hook() -> bool:
|
|
400
400
|
"(most recent call last)",
|
401
401
|
"(from remote execution)",
|
402
402
|
)
|
403
|
-
return cast(
|
403
|
+
return cast(list[str], out_list)
|
404
404
|
return original_structured_traceback( # type: ignore[no-any-return]
|
405
405
|
self, etype, evalue, etb, tb_offset, **kwargs
|
406
406
|
)
|
@@ -6,19 +6,10 @@ import pickle
|
|
6
6
|
import sys
|
7
7
|
import textwrap
|
8
8
|
from pathlib import Path, PurePath
|
9
|
-
from typing import
|
10
|
-
Any,
|
11
|
-
Callable,
|
12
|
-
List,
|
13
|
-
Optional,
|
14
|
-
Type,
|
15
|
-
Union,
|
16
|
-
cast,
|
17
|
-
get_args,
|
18
|
-
get_origin,
|
19
|
-
)
|
9
|
+
from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
|
20
10
|
|
21
11
|
import cloudpickle as cp
|
12
|
+
from packaging import version
|
22
13
|
|
23
14
|
from snowflake import snowpark
|
24
15
|
from snowflake.ml.jobs._utils import constants, types
|
@@ -107,11 +98,18 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
107
98
|
head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
|
108
99
|
if [ $? -eq 0 ]; then
|
109
100
|
# Parse the output using read
|
110
|
-
read head_index head_ip <<< "$head_info"
|
101
|
+
read head_index head_ip head_status<<< "$head_info"
|
111
102
|
|
112
103
|
# Use the parsed variables
|
113
104
|
echo "Head Instance Index: $head_index"
|
114
105
|
echo "Head Instance IP: $head_ip"
|
106
|
+
echo "Head Instance Status: $head_status"
|
107
|
+
|
108
|
+
# If the head status is not "READY" or "PENDING", exit early
|
109
|
+
if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
|
110
|
+
echo "Head instance status is not READY or PENDING. Exiting."
|
111
|
+
exit 0
|
112
|
+
fi
|
115
113
|
|
116
114
|
else
|
117
115
|
echo "Error: Failed to get head instance information."
|
@@ -277,7 +275,7 @@ class JobPayload:
|
|
277
275
|
source: Union[str, Path, Callable[..., Any]],
|
278
276
|
entrypoint: Optional[Union[str, Path]] = None,
|
279
277
|
*,
|
280
|
-
pip_requirements: Optional[
|
278
|
+
pip_requirements: Optional[list[str]] = None,
|
281
279
|
) -> None:
|
282
280
|
self.source = Path(source) if isinstance(source, str) else source
|
283
281
|
self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
|
@@ -288,17 +286,19 @@ class JobPayload:
|
|
288
286
|
stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
|
289
287
|
source = resolve_source(self.source)
|
290
288
|
entrypoint = resolve_entrypoint(source, self.entrypoint)
|
289
|
+
pip_requirements = self.pip_requirements or []
|
291
290
|
|
292
291
|
# Create stage if necessary
|
293
292
|
stage_name = stage_path.parts[0].lstrip("@")
|
294
293
|
# Explicitly check if stage exists first since we may not have CREATE STAGE privilege
|
295
294
|
try:
|
296
|
-
session.sql(
|
295
|
+
session.sql("describe stage identifier(?)", params=[stage_name]).collect()
|
297
296
|
except sp_exceptions.SnowparkSQLException:
|
298
297
|
session.sql(
|
299
|
-
|
298
|
+
"create stage if not exists identifier(?)"
|
300
299
|
" encryption = ( type = 'SNOWFLAKE_SSE' )"
|
301
|
-
" comment = 'Created by snowflake.ml.jobs Python API'"
|
300
|
+
" comment = 'Created by snowflake.ml.jobs Python API'",
|
301
|
+
params=[stage_name],
|
302
302
|
).collect()
|
303
303
|
|
304
304
|
# Upload payload to stage
|
@@ -311,6 +311,8 @@ class JobPayload:
|
|
311
311
|
overwrite=True,
|
312
312
|
)
|
313
313
|
source = Path(entrypoint.file_path.parent)
|
314
|
+
if not any(r.startswith("cloudpickle") for r in pip_requirements):
|
315
|
+
pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
|
314
316
|
elif source.is_dir():
|
315
317
|
# Manually traverse the directory and upload each file, since Snowflake PUT
|
316
318
|
# can't handle directories. Reduce the number of PUT operations by using
|
@@ -335,10 +337,10 @@ class JobPayload:
|
|
335
337
|
|
336
338
|
# Upload requirements
|
337
339
|
# TODO: Check if payload includes both a requirements.txt file and pip_requirements
|
338
|
-
if
|
340
|
+
if pip_requirements:
|
339
341
|
# Upload requirements.txt to stage
|
340
342
|
session.file.put_stream(
|
341
|
-
io.BytesIO("\n".join(
|
343
|
+
io.BytesIO("\n".join(pip_requirements).encode()),
|
342
344
|
stage_location=stage_path.joinpath("requirements.txt").as_posix(),
|
343
345
|
auto_compress=False,
|
344
346
|
overwrite=True,
|
@@ -364,7 +366,7 @@ class JobPayload:
|
|
364
366
|
auto_compress=False,
|
365
367
|
)
|
366
368
|
|
367
|
-
python_entrypoint:
|
369
|
+
python_entrypoint: list[Union[str, PurePath]] = [
|
368
370
|
PurePath("mljob_launcher.py"),
|
369
371
|
entrypoint.file_path.relative_to(source),
|
370
372
|
]
|
@@ -381,7 +383,7 @@ class JobPayload:
|
|
381
383
|
)
|
382
384
|
|
383
385
|
|
384
|
-
def _get_parameter_type(param: inspect.Parameter) -> Optional[
|
386
|
+
def _get_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
|
385
387
|
# Unwrap Optional type annotations
|
386
388
|
param_type = param.annotation
|
387
389
|
if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
|
@@ -390,10 +392,10 @@ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
|
|
390
392
|
# Return None for empty type annotations
|
391
393
|
if param_type == inspect.Parameter.empty:
|
392
394
|
return None
|
393
|
-
return cast(
|
395
|
+
return cast(type[object], param_type)
|
394
396
|
|
395
397
|
|
396
|
-
def _validate_parameter_type(param_type:
|
398
|
+
def _validate_parameter_type(param_type: type[object], param_name: str) -> None:
|
397
399
|
# Validate param_type is a supported type
|
398
400
|
if param_type not in _SUPPORTED_ARG_TYPES:
|
399
401
|
raise ValueError(
|
@@ -505,13 +507,6 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
|
|
505
507
|
# https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
|
506
508
|
source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
|
507
509
|
|
508
|
-
func_code = f"""
|
509
|
-
{source_code_comment}
|
510
|
-
|
511
|
-
import pickle
|
512
|
-
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
|
513
|
-
"""
|
514
|
-
|
515
510
|
arg_dict_name = "kwargs"
|
516
511
|
if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
|
517
512
|
param_code = f"{arg_dict_name} = {{}}"
|
@@ -519,25 +514,29 @@ import pickle
|
|
519
514
|
param_code = _generate_param_handler_code(signature, arg_dict_name)
|
520
515
|
|
521
516
|
return f"""
|
522
|
-
### Version guard to check compatibility across Python versions ###
|
523
|
-
import os
|
524
517
|
import sys
|
525
|
-
import
|
526
|
-
|
527
|
-
if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
|
528
|
-
warnings.warn(
|
529
|
-
"Python version mismatch: job was created using"
|
530
|
-
" python{sys.version_info.major}.{sys.version_info.minor}"
|
531
|
-
f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
|
532
|
-
" Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
|
533
|
-
" This will be fixed in a future release; for now, please use Python version"
|
534
|
-
f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
|
535
|
-
RuntimeWarning,
|
536
|
-
stacklevel=0,
|
537
|
-
)
|
538
|
-
### End version guard ###
|
518
|
+
import pickle
|
539
519
|
|
540
|
-
|
520
|
+
try:
|
521
|
+
{textwrap.indent(source_code_comment, ' ')}
|
522
|
+
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
|
523
|
+
except (TypeError, pickle.PickleError):
|
524
|
+
if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
|
525
|
+
raise RuntimeError(
|
526
|
+
"Failed to deserialize function due to Python version mismatch."
|
527
|
+
f" Runtime environment is Python {{sys.version_info.major}}.{{sys.version_info.minor}}"
|
528
|
+
" but function was serialized using Python {sys.version_info.major}.{sys.version_info.minor}."
|
529
|
+
) from None
|
530
|
+
raise
|
531
|
+
except AttributeError as e:
|
532
|
+
if 'cloudpickle' in str(e):
|
533
|
+
import cloudpickle as cp
|
534
|
+
raise RuntimeError(
|
535
|
+
"Failed to deserialize function due to cloudpickle version mismatch."
|
536
|
+
f" Runtime environment uses cloudpickle=={{cp.__version__}}"
|
537
|
+
" but job was serialized using cloudpickle=={cp.__version__}."
|
538
|
+
) from e
|
539
|
+
raise
|
541
540
|
|
542
541
|
if __name__ == '__main__':
|
543
542
|
{textwrap.indent(param_code, ' ')}
|
@@ -29,7 +29,7 @@ def get_self_ip() -> Optional[str]:
|
|
29
29
|
return None
|
30
30
|
|
31
31
|
|
32
|
-
def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
|
32
|
+
def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
|
33
33
|
"""Get the first instance of a batch job based on start time and instance ID.
|
34
34
|
|
35
35
|
Args:
|
@@ -42,7 +42,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
|
|
42
42
|
|
43
43
|
session = session_utils.get_session()
|
44
44
|
df = session.sql(f"show service instances in service {service_name}")
|
45
|
-
result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()
|
45
|
+
result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
|
46
46
|
|
47
47
|
if not result:
|
48
48
|
return None
|
@@ -57,7 +57,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
|
|
57
57
|
ip_address = head_instance["ip_address"]
|
58
58
|
try:
|
59
59
|
socket.inet_aton(ip_address) # Validate IPv4 address
|
60
|
-
return (head_instance["instance_id"], ip_address)
|
60
|
+
return (head_instance["instance_id"], ip_address, head_instance["status"])
|
61
61
|
except OSError:
|
62
62
|
logger.error(f"Error: Invalid IP address format: {ip_address}")
|
63
63
|
return None
|
@@ -110,7 +110,7 @@ def main():
|
|
110
110
|
head_info = get_first_instance(args.service_name)
|
111
111
|
if head_info:
|
112
112
|
# Print to stdout to allow capture but don't use logger
|
113
|
-
sys.stdout.write(
|
113
|
+
sys.stdout.write(" ".join(head_info) + "\n")
|
114
114
|
sys.exit(0)
|
115
115
|
time.sleep(args.retry_interval)
|
116
116
|
# If we get here, we've timed out
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import argparse
|
2
|
+
import copy
|
2
3
|
import importlib.util
|
3
4
|
import json
|
4
5
|
import os
|
@@ -7,7 +8,7 @@ import sys
|
|
7
8
|
import traceback
|
8
9
|
import warnings
|
9
10
|
from pathlib import Path
|
10
|
-
from typing import Any,
|
11
|
+
from typing import Any, Optional
|
11
12
|
|
12
13
|
import cloudpickle
|
13
14
|
|
@@ -27,7 +28,7 @@ except ImportError:
|
|
27
28
|
from dataclasses import dataclass
|
28
29
|
|
29
30
|
@dataclass(frozen=True)
|
30
|
-
class ExecutionResult:
|
31
|
+
class ExecutionResult: # type: ignore[no-redef]
|
31
32
|
result: Optional[Any] = None
|
32
33
|
exception: Optional[BaseException] = None
|
33
34
|
|
@@ -35,7 +36,7 @@ except ImportError:
|
|
35
36
|
def success(self) -> bool:
|
36
37
|
return self.exception is None
|
37
38
|
|
38
|
-
def to_dict(self) ->
|
39
|
+
def to_dict(self) -> dict[str, Any]:
|
39
40
|
"""Return the serializable dictionary."""
|
40
41
|
if isinstance(self.exception, BaseException):
|
41
42
|
exc_type = type(self.exception)
|
@@ -58,7 +59,7 @@ class SimpleJSONEncoder(json.JSONEncoder):
|
|
58
59
|
try:
|
59
60
|
return super().default(obj)
|
60
61
|
except TypeError:
|
61
|
-
return
|
62
|
+
return f"Unserializable object: {repr(obj)}"
|
62
63
|
|
63
64
|
|
64
65
|
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
|
@@ -136,7 +137,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
136
137
|
while tb and tb.tb_frame.f_code.co_filename in skip_files:
|
137
138
|
# Skip any frames preceding user script execution
|
138
139
|
tb = tb.tb_next
|
139
|
-
|
140
|
+
cleaned_ex = copy.copy(e) # Need to create a mutable copy of exception to set __traceback__
|
141
|
+
cleaned_ex = cleaned_ex.with_traceback(tb)
|
142
|
+
result_obj = ExecutionResult(exception=cleaned_ex)
|
140
143
|
raise
|
141
144
|
finally:
|
142
145
|
result_dict = result_obj.to_dict()
|
@@ -9,7 +9,7 @@ import logging
|
|
9
9
|
import socket
|
10
10
|
import sys
|
11
11
|
import time
|
12
|
-
from typing import Any
|
12
|
+
from typing import Any
|
13
13
|
|
14
14
|
import ray
|
15
15
|
from constants import (
|
@@ -33,34 +33,34 @@ class ShutdownSignal:
|
|
33
33
|
self.acknowledged_workers = set()
|
34
34
|
logging.info(f"ShutdownSignal actor created on {self.hostname}")
|
35
35
|
|
36
|
-
def request_shutdown(self) ->
|
36
|
+
def request_shutdown(self) -> dict[str, Any]:
|
37
37
|
"""Signal workers to shut down"""
|
38
38
|
self.shutdown_requested = True
|
39
39
|
self.timestamp = time.time()
|
40
40
|
logging.info(f"Shutdown requested by head node at {self.timestamp}")
|
41
41
|
return {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
|
42
42
|
|
43
|
-
def should_shutdown(self) ->
|
43
|
+
def should_shutdown(self) -> dict[str, Any]:
|
44
44
|
"""Check if shutdown has been requested"""
|
45
45
|
return {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
|
46
46
|
|
47
|
-
def ping(self) ->
|
47
|
+
def ping(self) -> dict[str, Any]:
|
48
48
|
"""Simple method to test connectivity"""
|
49
49
|
return {"status": "alive", "host": self.hostname}
|
50
50
|
|
51
|
-
def acknowledge_shutdown(self, worker_id: str) ->
|
51
|
+
def acknowledge_shutdown(self, worker_id: str) -> dict[str, Any]:
|
52
52
|
"""Worker acknowledges it has received the shutdown signal and is terminating"""
|
53
53
|
self.acknowledged_workers.add(worker_id)
|
54
54
|
logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {len(self.acknowledged_workers)}")
|
55
55
|
|
56
56
|
return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": len(self.acknowledged_workers)}
|
57
57
|
|
58
|
-
def get_acknowledgment_workers(self) ->
|
58
|
+
def get_acknowledgment_workers(self) -> set[str]:
|
59
59
|
"""Get the set of workers who have acknowledged shutdown"""
|
60
60
|
return self.acknowledged_workers
|
61
61
|
|
62
62
|
|
63
|
-
def get_worker_node_ids() ->
|
63
|
+
def get_worker_node_ids() -> list[str]:
|
64
64
|
"""Get the IDs of all active worker nodes.
|
65
65
|
|
66
66
|
Returns:
|
@@ -127,7 +127,7 @@ def verify_shutdown(shutdown_signal: ActorHandle) -> None:
|
|
127
127
|
logging.debug(f"Shutdown status check: {check}")
|
128
128
|
|
129
129
|
|
130
|
-
def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids:
|
130
|
+
def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: list[str], wait_time: int) -> None:
|
131
131
|
"""Wait for workers to acknowledge shutdown.
|
132
132
|
|
133
133
|
Args:
|