snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Mapping, Optional
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
4
|
from snowflake.ml._internal.utils import (
|
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
|
|
15
15
|
MODEL_JSON_VERSION_NAME_FIELD = "version_name"
|
16
16
|
|
17
17
|
|
18
|
-
def _build_sql_list_from_columns(columns:
|
18
|
+
def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
|
19
19
|
sql_list = ", ".join([f"'{column}'" for column in columns])
|
20
20
|
return f"({sql_list})"
|
21
21
|
|
@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
|
|
60
60
|
function_name: str,
|
61
61
|
warehouse_name: sql_identifier.SqlIdentifier,
|
62
62
|
timestamp_column: sql_identifier.SqlIdentifier,
|
63
|
-
id_columns:
|
64
|
-
prediction_score_columns:
|
65
|
-
prediction_class_columns:
|
66
|
-
actual_score_columns:
|
67
|
-
actual_class_columns:
|
63
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
64
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
65
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
66
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
67
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
68
68
|
refresh_interval: str,
|
69
69
|
aggregation_window: str,
|
70
70
|
baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
|
71
71
|
baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
|
72
72
|
baseline: Optional[sql_identifier.SqlIdentifier] = None,
|
73
|
-
statement_params: Optional[
|
73
|
+
statement_params: Optional[dict[str, Any]] = None,
|
74
74
|
) -> None:
|
75
75
|
baseline_sql = ""
|
76
76
|
if baseline:
|
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
|
|
103
103
|
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
104
104
|
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
105
105
|
monitor_name: sql_identifier.SqlIdentifier,
|
106
|
-
statement_params: Optional[
|
106
|
+
statement_params: Optional[dict[str, Any]] = None,
|
107
107
|
) -> None:
|
108
108
|
search_database_name = database_name or self._database_name
|
109
109
|
search_schema_name = schema_name or self._schema_name
|
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
|
|
116
116
|
def show_model_monitors(
|
117
117
|
self,
|
118
118
|
*,
|
119
|
-
statement_params: Optional[
|
120
|
-
) ->
|
119
|
+
statement_params: Optional[dict[str, Any]] = None,
|
120
|
+
) -> list[snowpark.Row]:
|
121
121
|
fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
122
122
|
return (
|
123
123
|
query_result_checker.SqlResultValidator(
|
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
|
|
135
135
|
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
136
136
|
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
137
137
|
monitor_name: sql_identifier.SqlIdentifier,
|
138
|
-
statement_params: Optional[
|
138
|
+
statement_params: Optional[dict[str, Any]] = None,
|
139
139
|
) -> bool:
|
140
140
|
search_database_name = database_name or self._database_name
|
141
141
|
search_schema_name = schema_name or self._schema_name
|
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
|
|
153
153
|
def validate_monitor_warehouse(
|
154
154
|
self,
|
155
155
|
warehouse_name: sql_identifier.SqlIdentifier,
|
156
|
-
statement_params: Optional[
|
156
|
+
statement_params: Optional[dict[str, Any]] = None,
|
157
157
|
) -> None:
|
158
158
|
"""Validate warehouse provided for monitoring exists.
|
159
159
|
|
@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
|
|
177
177
|
*,
|
178
178
|
source_column_schema: Mapping[str, types.DataType],
|
179
179
|
timestamp_column: sql_identifier.SqlIdentifier,
|
180
|
-
prediction_score_columns:
|
181
|
-
prediction_class_columns:
|
182
|
-
actual_score_columns:
|
183
|
-
actual_class_columns:
|
184
|
-
id_columns:
|
180
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
181
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
182
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
183
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
184
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
185
185
|
) -> None:
|
186
186
|
"""Ensures all columns exist in the source table.
|
187
187
|
|
@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
|
|
221
221
|
source_schema: Optional[sql_identifier.SqlIdentifier],
|
222
222
|
source: sql_identifier.SqlIdentifier,
|
223
223
|
timestamp_column: sql_identifier.SqlIdentifier,
|
224
|
-
prediction_score_columns:
|
225
|
-
prediction_class_columns:
|
226
|
-
actual_score_columns:
|
227
|
-
actual_class_columns:
|
228
|
-
id_columns:
|
224
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
225
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
226
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
227
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
228
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
229
229
|
) -> None:
|
230
230
|
source_database = source_database or self._database_name
|
231
231
|
source_schema = source_schema or self._schema_name
|
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
|
|
250
250
|
self,
|
251
251
|
operation: str,
|
252
252
|
monitor_name: sql_identifier.SqlIdentifier,
|
253
|
-
statement_params: Optional[
|
253
|
+
statement_params: Optional[dict[str, Any]] = None,
|
254
254
|
) -> None:
|
255
255
|
if operation not in {"SUSPEND", "RESUME"}:
|
256
256
|
raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
|
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
|
|
263
263
|
def suspend_monitor(
|
264
264
|
self,
|
265
265
|
monitor_name: sql_identifier.SqlIdentifier,
|
266
|
-
statement_params: Optional[
|
266
|
+
statement_params: Optional[dict[str, Any]] = None,
|
267
267
|
) -> None:
|
268
268
|
self._alter_monitor(
|
269
269
|
operation="SUSPEND",
|
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
|
|
274
274
|
def resume_monitor(
|
275
275
|
self,
|
276
276
|
monitor_name: sql_identifier.SqlIdentifier,
|
277
|
-
statement_params: Optional[
|
277
|
+
statement_params: Optional[dict[str, Any]] = None,
|
278
278
|
) -> None:
|
279
279
|
self._alter_monitor(
|
280
280
|
operation="RESUME",
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
from snowflake import snowpark
|
5
5
|
from snowflake.ml._internal.utils import sql_identifier
|
@@ -20,7 +20,7 @@ class ModelMonitorManager:
|
|
20
20
|
database_name: sql_identifier.SqlIdentifier,
|
21
21
|
schema_name: sql_identifier.SqlIdentifier,
|
22
22
|
*,
|
23
|
-
statement_params: Optional[
|
23
|
+
statement_params: Optional[dict[str, Any]] = None,
|
24
24
|
) -> None:
|
25
25
|
"""
|
26
26
|
Opens a ModelMonitorManager for a given database and schema.
|
@@ -64,7 +64,7 @@ class ModelMonitorManager:
|
|
64
64
|
f"Found: {existing_target_methods}."
|
65
65
|
)
|
66
66
|
|
67
|
-
def _build_column_list_from_input(self, columns: Optional[
|
67
|
+
def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
|
68
68
|
return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
|
69
69
|
|
70
70
|
def add_monitor(
|
@@ -172,7 +172,7 @@ class ModelMonitorManager:
|
|
172
172
|
"""
|
173
173
|
rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
|
174
174
|
|
175
|
-
def model_match_fn(model_details:
|
175
|
+
def model_match_fn(model_details: dict[str, str]) -> bool:
|
176
176
|
return (
|
177
177
|
model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
|
178
178
|
and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
|
@@ -215,7 +215,7 @@ class ModelMonitorManager:
|
|
215
215
|
name=monitor_name_id,
|
216
216
|
)
|
217
217
|
|
218
|
-
def show_model_monitors(self) ->
|
218
|
+
def show_model_monitors(self) -> list[snowpark.Row]:
|
219
219
|
"""Show all model monitors in the registry.
|
220
220
|
|
221
221
|
Returns:
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
|
-
from typing import
|
2
|
+
from typing import Optional
|
3
3
|
|
4
4
|
from snowflake.ml.model._client.model import model_version_impl
|
5
5
|
|
@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
|
|
14
14
|
timestamp_column: str
|
15
15
|
"""Name of column in the source containing timestamp."""
|
16
16
|
|
17
|
-
id_columns:
|
17
|
+
id_columns: list[str]
|
18
18
|
"""List of columns in the source containing unique identifiers."""
|
19
19
|
|
20
|
-
prediction_score_columns: Optional[
|
20
|
+
prediction_score_columns: Optional[list[str]] = None
|
21
21
|
"""List of columns in the source containing prediction scores.
|
22
22
|
Can be regression scores for regression models and probability scores for classification models."""
|
23
23
|
|
24
|
-
prediction_class_columns: Optional[
|
24
|
+
prediction_class_columns: Optional[list[str]] = None
|
25
25
|
"""List of columns in the source containing prediction classes for classification models."""
|
26
26
|
|
27
|
-
actual_score_columns: Optional[
|
27
|
+
actual_score_columns: Optional[list[str]] = None
|
28
28
|
"""List of columns in the source containing actual scores."""
|
29
29
|
|
30
|
-
actual_class_columns: Optional[
|
30
|
+
actual_class_columns: Optional[list[str]] = None
|
31
31
|
"""List of columns in the source containing actual classes for classification models."""
|
32
32
|
|
33
33
|
baseline: Optional[str] = None
|
@@ -0,0 +1,286 @@
|
|
1
|
+
from typing import Union, cast, overload
|
2
|
+
|
3
|
+
import altair as alt
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
import snowflake.snowpark.dataframe as sp_df
|
8
|
+
from snowflake import snowpark
|
9
|
+
from snowflake.ml.model import model_signature, type_hints
|
10
|
+
from snowflake.ml.model._signatures import snowpark_handler
|
11
|
+
|
12
|
+
|
13
|
+
@overload
def plot_force(
    shap_row: snowpark.Row,
    features_row: snowpark.Row,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


@overload
def plot_force(
    shap_row: pd.Series,
    features_row: pd.Series,
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


def plot_force(
    shap_row: Union[pd.Series, snowpark.Row],
    features_row: Union[pd.Series, snowpark.Row],
    base_value: float = 0.0,
    figsize: tuple[float, float] = (600, 200),
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    """
    Create a force plot for SHAP values with stacked bars based on influence direction.

    SHAP values and feature values are paired by position. Bars are stacked outward from the model
    output for this instance (``base_value`` plus the sum of all SHAP values): positive contributions
    extend towards smaller x, negative contributions towards larger x, largest magnitude first.
    A dashed vertical rule marks ``base_value``.

    Args:
        shap_row: pandas Series or snowpark Row containing SHAP values for a specific instance
        features_row: pandas Series or snowpark Row containing the feature values for the same instance,
            in the same order as ``shap_row``; its index (or Row field names) supplies the feature names
        base_value: base value of the predictions. Defaults to 0, but is usually the model's average prediction
        figsize: tuple of (width, height) for the plot
        contribution_threshold:
            Only features whose magnitude is at least contribution_threshold as a fraction of the
            total absolute SHAP values will be plotted. Defaults to 0.05 (5%)

    Returns:
        Altair chart object

    Raises:
        ValueError: If ``shap_row`` and ``features_row`` have different lengths.
    """
    if isinstance(shap_row, snowpark.Row):
        shap_row = pd.Series(shap_row.as_dict())
    if isinstance(features_row, snowpark.Row):
        features_row = pd.Series(features_row.as_dict())

    # Values are paired positionally; a length mismatch would otherwise raise an opaque IndexError
    # or silently fold unplotted SHAP values into the stacking origin below.
    if len(shap_row) != len(features_row):
        raise ValueError(
            f"shap_row and features_row must have the same length, but got {len(shap_row)} and {len(features_row)}"
        )

    positive_label = "Positive"
    negative_label = "Negative"
    # Single direction -> color mapping (positive = red, negative = blue) shared by every layer,
    # so the bars and the arrow markers can never disagree in the shared layered color scale.
    blue_color = "#1f77b4"
    red_color = "#d62728"
    direction_domain = [positive_label, negative_label]
    direction_colors = [red_color, blue_color]

    # Create a dataframe for plotting
    plot_df = pd.DataFrame(
        [
            {
                "feature": feature,
                "feature_value": features_row.iloc[index],
                "feature_annotated": f"{feature}: {features_row.iloc[index]}",
                "influence_value": shap_row.iloc[index],
                "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
            }
            for index, feature in enumerate(features_row.index)
        ]
    )

    # Calculate cumulative positions for the stacked bars. Both stacks start at the model output
    # for this instance, which is expressed in the same units as base_value.
    output_value = base_value + np.sum(shap_row)
    current_position_pos = output_value
    current_position_neg = output_value
    positions = []

    influence_magnitudes = plot_df["influence_value"].abs()
    total_abs_value_sum = np.sum(influence_magnitudes)
    max_abs_value = influence_magnitudes.max()
    spacing = max_abs_value * 0.07  # Use 7% of the largest magnitude as the gap between adjacent bars

    # Sort by absolute value to have largest impacts first
    plot_df = plot_df.reindex(influence_magnitudes.sort_values(ascending=False).index)
    for _, row in plot_df.iterrows():
        # Skip features with small contributions
        row_influence_value = row["influence_value"]
        if abs(row_influence_value) / total_abs_value_sum < contribution_threshold:
            continue

        if row_influence_value >= 0:
            start = current_position_pos - spacing
            end = current_position_pos - row_influence_value
            current_position_pos = end
        else:
            start = current_position_neg + spacing
            end = current_position_neg + abs(row_influence_value)
            current_position_neg = end

        positions.append(
            {
                "start": start,
                "end": end,
                "avg": (start + end) / 2,
                "influence_value": row_influence_value,
                "influence_annotated": f"Influence: {row_influence_value}",
                "feature_value": row["feature_value"],
                "feature_annotated": row["feature_annotated"],
                "bar_direction": row["bar_direction"],
            }
        )

    position_df = pd.DataFrame(positions)

    # Create force plot using Altair
    width, height = figsize
    bars: alt.Chart = (
        alt.Chart(position_df)
        .mark_bar(size=10)
        .encode(
            x=alt.X("start:Q", title="Feature Impact"),
            x2=alt.X2("end:Q"),
            color=alt.Color(
                "bar_direction:N",
                scale=alt.Scale(domain=direction_domain, range=direction_colors),
                legend=alt.Legend(title="Influence Direction"),
            ),
            tooltip=["influence_value", "feature_value"],
        )
        .properties(title="Feature Influence (SHAP values)", width=width, height=height)
    ).interactive()

    # Triangle marker at the start of each bar, rotated by direction and colored like its bar.
    arrow: alt.Chart = (
        alt.Chart(position_df)
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.X("start:Q"),
            angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=direction_domain, range=[90, -90])),
            color=alt.Color("bar_direction:N", scale=alt.Scale(domain=direction_domain, range=direction_colors)),
            size=alt.SizeValue(300),
            tooltip=alt.value(None),
        )
    )

    # Add a dashed vertical line at the base value
    base_line: alt.Chart = (
        alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")
    )

    # Add "feature: value" text labels centered below each bar
    feature_labels = (
        alt.Chart(position_df)
        .mark_text(align="center", baseline="line-bottom", dy=30, fontSize=11)
        .encode(
            x=alt.X("avg:Q"),
            text=alt.Text("feature_annotated:N"),  # Pre-formatted "feature: value" string
            color=alt.value("grey"),  # Same neutral label color for both directions
            tooltip=["feature_value"],
        )
    )

    return cast(alt.LayerChart, bars + feature_labels + base_line + arrow)
|
170
|
+
|
171
|
+
|
172
|
+
def plot_influence_sensitivity(
    feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float] = (600, 400)
) -> alt.Chart:
    """
    Create a SHAP dependence scatter plot for a specific feature.

    If any single distinct feature value occurs more than 10 times, the feature is treated as
    categorical: values go on a nominal x axis, are colored per value and are spread horizontally
    with random jitter. Otherwise the x axis is quantitative.

    Args:
        feature_values: pandas Series containing the feature values for a specific feature
        shap_values: pandas Series containing the SHAP values for the same feature; it is combined
            with ``feature_values`` into one DataFrame, so pandas aligns the two on their index
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If ``feature_values`` is empty.
    """
    if len(feature_values) == 0:
        raise ValueError("feature_values must contain at least one value.")

    # Occurrence count of every distinct feature value. The most frequent value bounds the average
    # count from above, so it alone decides whether the feature looks discrete.
    _, value_counts = np.unique(feature_values.values, return_counts=True)
    is_categorical = bool(value_counts.max() > 10)

    kwargs = (
        {
            "x": alt.X("feature_value:N", title="Feature Value"),
            "color": alt.Color("feature_value:N").legend(None),
            # Random horizontal offset within each category so overlapping points stay visible.
            "xOffset": "jitter:Q",
        }
        if is_categorical
        else {"x": alt.X("feature_value:Q", title="Feature Value")}
    )

    # Create a dataframe for plotting
    plot_df = pd.DataFrame({"feature_value": feature_values, "shap_value": shap_values})

    width, height = figsize

    # Create scatter plot
    scatter = (
        alt.Chart(plot_df)
        .transform_calculate(jitter="random()")
        .mark_circle(size=60, opacity=0.7)
        .encode(
            y=alt.Y("shap_value:Q", title="SHAP Value"),
            tooltip=["feature_value", "shap_value"],
            **kwargs,
        )
        .properties(title="SHAP Dependence Scatter Plot", width=width, height=height)
    )

    return cast(alt.Chart, scatter)
|
222
|
+
|
223
|
+
|
224
|
+
def plot_violin(
    shap_df: type_hints.SupportedDataType,
    feature_df: type_hints.SupportedDataType,
    figsize: tuple[float, float] = (600, 200),
) -> alt.Chart:
    """
    Create a violin plot per feature showing the distribution of SHAP values.

    SHAP columns are paired with feature columns by position. One row is drawn per feature,
    ordered by the feature's total absolute SHAP value, largest first.

    Args:
        shap_df: tabular data (any type accepted by ``_convert_to_pandas_df``) with one column of
            SHAP values per feature
        feature_df: tabular data with the corresponding feature values; only its column names are
            used, as the row labels
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If either input is not 2D, or the two inputs have different numbers of columns.
    """

    shap_df_pd = _convert_to_pandas_df(shap_df)
    feature_df_pd = _convert_to_pandas_df(feature_df)

    # Validate explicitly (asserts are stripped under `python -O`).
    if shap_df_pd.ndim != 2:
        raise ValueError(f"shap_df must be 2D, but got shape {shap_df_pd.shape}")
    if feature_df_pd.ndim != 2:
        raise ValueError(f"feature_df must be 2D, but got shape {feature_df_pd.shape}")
    # Columns are paired by position, so both inputs must describe the same number of features.
    if shap_df_pd.shape[1] != feature_df_pd.shape[1]:
        raise ValueError(
            "shap_df and feature_df must have the same number of columns, but got "
            f"{shap_df_pd.shape[1]} and {feature_df_pd.shape[1]}"
        )

    # Prepare long-format data: each feature name repeated once per SHAP row, matched with that
    # feature's SHAP column (transposing makes the flattened values column-major).
    plot_data = pd.DataFrame(
        {
            "feature_name": feature_df_pd.columns.repeat(shap_df_pd.shape[0]),
            "shap_value": shap_df_pd.transpose().values.flatten(),
        }
    )

    # Order the rows by the absolute sum of SHAP values per feature. Summing by position (not by
    # label) keeps the SHAP -> feature mapping valid even if SHAP column names repeat.
    feature_abs_sum = shap_df_pd.abs().sum(axis=0).reset_index(drop=True)
    sorted_positions = feature_abs_sum.sort_values(ascending=False).index
    column_sort_order = [feature_df_pd.columns[position] for position in sorted_positions]

    # Create the violin plot
    width, height = figsize
    violin = (
        alt.Chart(plot_data)
        .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
        .mark_area(orient="vertical")
        .encode(
            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
            x=alt.X("shap_value:Q", title="SHAP Value"),
            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
            color=alt.Color("feature_name:N", legend=None),
            tooltip=["feature_name", "shap_value"],
        )
        .properties(width=width, height=height)
    ).interactive()

    return cast(alt.Chart, violin)
|
278
|
+
|
279
|
+
|
280
|
+
def _convert_to_pandas_df(
    data: type_hints.SupportedDataType,
) -> pd.DataFrame:
    """Return ``data`` as a pandas DataFrame.

    Snowpark DataFrames are converted through the Snowpark handler; every other supported input
    type is converted with ``model_signature._convert_local_data_to_df``.
    """
    if not isinstance(data, sp_df.DataFrame):
        # Local (non-Snowpark) data: pandas, numpy, lists, etc.
        return model_signature._convert_local_data_to_df(data)
    return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(data)
|