snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +18 -16
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +10 -10
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +45 -46
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +18 -29
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +10 -5
- snowflake/ml/jobs/job.py +87 -30
- snowflake/ml/jobs/manager.py +86 -56
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +217 -32
- snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +17 -26
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +58 -32
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
- snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +24 -11
- snowflake/ml/model/model_signature.py +12 -23
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +55 -32
- snowflake/ml/registry/registry.py +39 -31
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
- snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|
24
24
|
logger = logging.getLogger(__name__)
|
25
25
|
|
26
26
|
|
27
|
-
def _validate_sentence_transformers_signatures(sigs:
|
27
|
+
def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
|
28
28
|
if list(sigs.keys()) != ["encode"]:
|
29
29
|
raise ValueError("target_methods can only be ['encode']")
|
30
30
|
|
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
48
48
|
HANDLER_TYPE = "sentence_transformers"
|
49
49
|
HANDLER_VERSION = "2024-03-15"
|
50
50
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
51
|
-
_HANDLER_MIGRATOR_PLANS:
|
51
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
52
52
|
|
53
53
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
54
54
|
DEFAULT_TARGET_METHODS = ["encode"]
|
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
166
166
|
],
|
167
167
|
check_local_version=True,
|
168
168
|
)
|
169
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
169
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
170
170
|
|
171
171
|
@staticmethod
|
172
172
|
def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
|
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
224
224
|
def _create_custom_model(
|
225
225
|
raw_model: "sentence_transformers.SentenceTransformer",
|
226
226
|
model_meta: model_meta_api.ModelMetadata,
|
227
|
-
) ->
|
227
|
+
) -> type[custom_model.CustomModel]:
|
228
228
|
batch_size = cast(
|
229
229
|
model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
|
230
230
|
).get("batch_size", None)
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
8
|
from typing_extensions import TypeGuard, Unpack
|
9
9
|
|
10
|
-
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml._internal import env, type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
13
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
19
19
|
)
|
20
20
|
from snowflake.ml.model._packager.model_task import model_task_utils
|
21
21
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
22
|
-
from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
|
23
22
|
|
24
23
|
if TYPE_CHECKING:
|
25
24
|
import sklearn.base
|
@@ -39,6 +38,35 @@ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "s
|
|
39
38
|
return model
|
40
39
|
|
41
40
|
|
41
|
+
def _apply_transforms_up_to_last_step(
    model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
    data: model_types.SupportedDataType,
    input_feature_names: Optional[list[str]] = None,
) -> pd.DataFrame:
    """Apply all transformations in the sklearn pipeline model up to the last step.

    For a ``Pipeline``, every step except the final estimator is applied to ``data`` in order. For any other
    model, ``data`` is only converted to a DataFrame.

    Args:
        model: A scikit-learn estimator or ``Pipeline``.
        data: Input data accepted by the pipeline's first step.
        input_feature_names: Names of the input features. When given, they are propagated through each step's
            ``get_feature_names_out`` and used as the columns of the returned DataFrame.

    Returns:
        The transformed data as a dense DataFrame.

    Raises:
        ValueError: If an intermediate step has no ``transform`` method, or if feature names were requested and
            an intermediate step has no ``get_feature_names_out`` method.
    """
    transformed_data = data
    output_features_names = input_feature_names

    if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
            # sklearn lets intermediate steps be disabled with None or "passthrough". Such steps are identity
            # transforms: they change neither the data nor the feature names.
            if step is None or (isinstance(step, str) and step == "passthrough"):
                continue
            if not hasattr(step, "transform"):
                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
            transformed_data = step.transform(transformed_data)
            if output_features_names is None:
                continue
            if not hasattr(step, "get_feature_names_out"):
                raise ValueError(
                    f"Step '{step_name}' in the pipeline does not have a 'get_feature_names_out' method. "
                    "Feature names cannot be propagated."
                )
            output_features_names = step.get_feature_names_out(output_features_names)

    # Densify sparse results (e.g. OneHotEncoder output) before building a DataFrame. csr_matrix is the common
    # case, but every scipy sparse matrix/array format (csc, coo, csr_array, ...) exposes ``toarray``.
    to_dense = getattr(transformed_data, "toarray", None)
    if callable(to_dense) and not isinstance(transformed_data, (pd.DataFrame, np.ndarray)):
        transformed_data = to_dense()
    return pd.DataFrame(transformed_data, columns=output_features_names)
|
68
|
+
|
69
|
+
|
42
70
|
@final
|
43
71
|
class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
|
44
72
|
"""Handler for scikit-learn based model.
|
@@ -49,7 +77,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
49
77
|
HANDLER_TYPE = "sklearn"
|
50
78
|
HANDLER_VERSION = "2023-12-01"
|
51
79
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
52
|
-
_HANDLER_MIGRATOR_PLANS:
|
80
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
53
81
|
|
54
82
|
DEFAULT_TARGET_METHODS = [
|
55
83
|
"predict",
|
@@ -59,7 +87,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
59
87
|
"decision_function",
|
60
88
|
"score_samples",
|
61
89
|
]
|
62
|
-
|
90
|
+
|
91
|
+
# Prioritize predict_proba as it gives multi-class probabilities
|
92
|
+
EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
|
63
93
|
|
64
94
|
@classmethod
|
65
95
|
def can_handle(
|
@@ -113,7 +143,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
113
143
|
raise ValueError("Sample input data is required to enable explainability.")
|
114
144
|
|
115
145
|
# If this is a pipeline and we are in the container runtime, check for distributed estimator.
|
116
|
-
if
|
146
|
+
if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
|
117
147
|
model = _unpack_container_runtime_pipeline(model)
|
118
148
|
|
119
149
|
if not is_sub_model:
|
@@ -161,17 +191,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
161
191
|
stacklevel=1,
|
162
192
|
)
|
163
193
|
enable_explainability = False
|
164
|
-
elif model_meta.task == model_types.Task.UNKNOWN
|
194
|
+
elif model_meta.task == model_types.Task.UNKNOWN:
|
195
|
+
enable_explainability = False
|
196
|
+
elif explain_target_method is None:
|
165
197
|
enable_explainability = False
|
166
198
|
else:
|
167
199
|
enable_explainability = True
|
168
200
|
if enable_explainability:
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
201
|
+
explain_target_method = str(explain_target_method) # mypy complains if we don't cast to str here
|
202
|
+
|
203
|
+
input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
|
204
|
+
transformed_background_data = _apply_transforms_up_to_last_step(
|
205
|
+
model=model,
|
206
|
+
data=background_data,
|
207
|
+
input_feature_names=[spec.name for spec in input_signature],
|
174
208
|
)
|
209
|
+
|
210
|
+
try:
|
211
|
+
model_meta = handlers_utils.add_inferred_explain_method_signature(
|
212
|
+
model_meta=model_meta,
|
213
|
+
explain_method="explain",
|
214
|
+
target_method=explain_target_method,
|
215
|
+
background_data=background_data,
|
216
|
+
explain_fn=cls._build_explain_fn(model, background_data, input_signature),
|
217
|
+
output_feature_names=transformed_background_data.columns,
|
218
|
+
)
|
219
|
+
except ValueError:
|
220
|
+
if kwargs.get("enable_explainability", None):
|
221
|
+
# user explicitly enabled explainability, so we should raise the error
|
222
|
+
raise ValueError(
|
223
|
+
"Explainability for this model is not supported. Please set `enable_explainability=False`"
|
224
|
+
)
|
225
|
+
|
175
226
|
handlers_utils.save_background_data(
|
176
227
|
model_blobs_dir_path,
|
177
228
|
cls.EXPLAIN_ARTIFACTS_DIR,
|
@@ -223,11 +274,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
223
274
|
)
|
224
275
|
|
225
276
|
if enable_explainability:
|
226
|
-
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
277
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
227
278
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
228
279
|
|
229
280
|
model_meta.env.include_if_absent(
|
230
|
-
[
|
281
|
+
[
|
282
|
+
model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
|
283
|
+
],
|
231
284
|
check_local_version=True,
|
232
285
|
)
|
233
286
|
|
@@ -265,7 +318,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
265
318
|
def _create_custom_model(
|
266
319
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
267
320
|
model_meta: model_meta_api.ModelMetadata,
|
268
|
-
) ->
|
321
|
+
) -> type[custom_model.CustomModel]:
|
269
322
|
def fn_factory(
|
270
323
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
271
324
|
signature: model_signature.ModelSignature,
|
@@ -287,37 +340,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
287
340
|
|
288
341
|
@custom_model.inference_api
|
289
342
|
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
290
|
-
|
291
|
-
|
292
|
-
try:
|
293
|
-
explainer = shap.Explainer(raw_model, background_data)
|
294
|
-
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
295
|
-
except TypeError:
|
296
|
-
try:
|
297
|
-
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
298
|
-
|
299
|
-
if isinstance(X, pd.DataFrame):
|
300
|
-
X = X.astype(dtype_map, copy=False)
|
301
|
-
if hasattr(raw_model, "predict_proba"):
|
302
|
-
if isinstance(X, np.ndarray):
|
303
|
-
explanations = shap.Explainer(
|
304
|
-
raw_model.predict_proba, background_data.values # type: ignore[union-attr]
|
305
|
-
)(X).values
|
306
|
-
else:
|
307
|
-
explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
|
308
|
-
elif hasattr(raw_model, "predict"):
|
309
|
-
if isinstance(X, np.ndarray):
|
310
|
-
explanations = shap.Explainer(
|
311
|
-
raw_model.predict, background_data.values # type: ignore[union-attr]
|
312
|
-
)(X).values
|
313
|
-
else:
|
314
|
-
explanations = shap.Explainer(raw_model.predict, background_data)(X).values
|
315
|
-
else:
|
316
|
-
raise ValueError("Missing any supported target method to explain.")
|
317
|
-
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
|
318
|
-
except TypeError as e:
|
319
|
-
raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
|
320
|
-
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
343
|
+
fn = cls._build_explain_fn(raw_model, background_data, signature.inputs)
|
344
|
+
return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
|
321
345
|
|
322
346
|
if target_method == "explain":
|
323
347
|
return explain_fn
|
@@ -340,3 +364,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
340
364
|
skl_model = _SKLModel(custom_model.ModelContext())
|
341
365
|
|
342
366
|
return skl_model
|
367
|
+
|
368
|
+
@classmethod
def _build_explain_fn(
    cls,
    model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
    background_data: model_types.SupportedDataType,
    input_specs: Sequence[model_signature.BaseFeatureSpec],
) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
    """Build a SHAP-based explanation function for a scikit-learn estimator or pipeline.

    For a ``Pipeline``, all steps but the last are applied to the data first, and only the final estimator is
    explained. The background data is transformed the same way and used as SHAP background.

    Args:
        model: A scikit-learn estimator or ``Pipeline``.
        background_data: Background (reference) data for SHAP, in the model's raw input format.
        input_specs: Input feature specs. They are used to cast DataFrame input to numpy dtypes when the first
            SHAP attempt fails with ``TypeError``.

    Returns:
        A function that maps input data to a 2D DataFrame of explanations, as produced by
        ``handlers_utils.convert_explanations_to_2D_df``.
    """
    import shap
    import sklearn.pipeline

    transformed_bg_data = _apply_transforms_up_to_last_step(model, background_data)
    # Only the final estimator is handed to SHAP; the preceding steps are applied manually to the data.
    predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model

    def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
        transformed_data = _apply_transforms_up_to_last_step(model, data)
        try:
            explainer = shap.Explainer(predictor, transformed_bg_data)
            return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
        except TypeError:
            # SHAP could not build an explainer from the estimator itself. Fall back to explaining one of its
            # prediction methods on plain numpy input, in EXPLAIN_TARGET_METHODS priority order.
            try:
                if isinstance(data, pd.DataFrame):
                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
                    transformed_data = _apply_transforms_up_to_last_step(model, data.astype(dtype_map))
                for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
                    if not hasattr(predictor, explain_target_method):
                        continue
                    explain_target_method_fn = getattr(predictor, explain_target_method)
                    explanations = shap.Explainer(explain_target_method_fn, transformed_bg_data.values)(
                        transformed_data.to_numpy()
                    ).values
                    return handlers_utils.convert_explanations_to_2D_df(model, explanations)
            except TypeError as e:
                # Callers treat ValueError as "explanation unsupported". A raw TypeError from SHAP would
                # bypass that handling, so surface it as ValueError.
                raise ValueError(f"Explanation for this model type not supported yet: {str(e)}") from e
            raise ValueError("Missing any supported target method to explain.")

    return explain_fn
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
36
36
|
HANDLER_TYPE = "snowml"
|
37
37
|
HANDLER_VERSION = "2023-12-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
40
40
|
|
41
41
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
42
42
|
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
264
264
|
def _create_custom_model(
|
265
265
|
raw_model: "BaseEstimator",
|
266
266
|
model_meta: model_meta_api.ModelMetadata,
|
267
|
-
) ->
|
267
|
+
) -> type[custom_model.CustomModel]:
|
268
268
|
def fn_factory(
|
269
269
|
raw_model: "BaseEstimator",
|
270
270
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from packaging import version
|
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
38
38
|
HANDLER_TYPE = "tensorflow"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
|
43
43
|
"2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
|
44
44
|
}
|
@@ -88,6 +88,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
88
88
|
import tensorflow
|
89
89
|
|
90
90
|
assert isinstance(model, tensorflow.Module)
|
91
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
91
92
|
|
92
93
|
is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
|
93
94
|
is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
|
@@ -112,8 +113,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
112
113
|
default_target_methods=default_target_methods,
|
113
114
|
)
|
114
115
|
|
115
|
-
multiple_inputs = kwargs.get("multiple_inputs", False)
|
116
|
-
|
117
116
|
if is_keras_model and len(target_methods) > 1:
|
118
117
|
raise ValueError("Keras model can only have one target method.")
|
119
118
|
|
@@ -188,7 +187,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
188
187
|
dependencies,
|
189
188
|
check_local_version=True,
|
190
189
|
)
|
191
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
190
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
192
191
|
|
193
192
|
@classmethod
|
194
193
|
def load_model(
|
@@ -198,7 +197,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
198
197
|
model_blobs_dir_path: str,
|
199
198
|
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
200
199
|
) -> "tensorflow.Module":
|
201
|
-
os.environ["TF_USE_LEGACY_KERAS"] = "1"
|
202
200
|
import tensorflow
|
203
201
|
|
204
202
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
@@ -209,7 +207,12 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
209
207
|
load_path = os.path.join(model_blob_path, model_blob_filename)
|
210
208
|
save_format = model_blob_options.get("save_format", "keras_tf")
|
211
209
|
if save_format == "keras_tf":
|
212
|
-
|
210
|
+
if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
|
211
|
+
import tf_keras
|
212
|
+
|
213
|
+
m = tf_keras.models.load_model(load_path)
|
214
|
+
else:
|
215
|
+
m = tensorflow.keras.models.load_model(load_path)
|
213
216
|
else:
|
214
217
|
m = tensorflow.saved_model.load(load_path)
|
215
218
|
|
@@ -230,7 +233,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
230
233
|
def _create_custom_model(
|
231
234
|
raw_model: "tensorflow.Module",
|
232
235
|
model_meta: model_meta_api.ModelMetadata,
|
233
|
-
) ->
|
236
|
+
) -> type[custom_model.CustomModel]:
|
234
237
|
multiple_inputs = cast(
|
235
238
|
model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
|
236
239
|
)["multiple_inputs"]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from typing_extensions import TypeGuard, Unpack
|
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
36
36
|
HANDLER_TYPE = "torchscript"
|
37
37
|
HANDLER_VERSION = "2025-03-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
40
|
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
41
|
}
|
42
42
|
|
@@ -76,6 +76,8 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
76
76
|
if enable_explainability:
|
77
77
|
raise NotImplementedError("Explainability is not supported for Torch Script model.")
|
78
78
|
|
79
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
80
|
+
|
79
81
|
import torch
|
80
82
|
|
81
83
|
assert isinstance(model, torch.jit.ScriptModule)
|
@@ -87,8 +89,6 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
87
89
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
88
90
|
)
|
89
91
|
|
90
|
-
multiple_inputs = kwargs.get("multiple_inputs", False)
|
91
|
-
|
92
92
|
def get_prediction(
|
93
93
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
94
94
|
) -> model_types.SupportedLocalDataType:
|
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
141
141
|
model_meta.env.include_if_absent(
|
142
142
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
143
143
|
)
|
144
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
144
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
145
145
|
|
146
146
|
@classmethod
|
147
147
|
def load_model(
|
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
181
181
|
def _create_custom_model(
|
182
182
|
raw_model: "torch.jit.ScriptModule",
|
183
183
|
model_meta: model_meta_api.ModelMetadata,
|
184
|
-
) ->
|
184
|
+
) -> type[custom_model.CustomModel]:
|
185
185
|
def fn_factory(
|
186
186
|
raw_model: "torch.jit.ScriptModule",
|
187
187
|
signature: model_signature.ModelSignature,
|
@@ -1,17 +1,7 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
Optional,
|
10
|
-
Type,
|
11
|
-
Union,
|
12
|
-
cast,
|
13
|
-
final,
|
14
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
15
5
|
|
16
6
|
import numpy as np
|
17
7
|
import pandas as pd
|
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
44
34
|
HANDLER_TYPE = "xgboost"
|
45
35
|
HANDLER_VERSION = "2023-12-01"
|
46
36
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
47
|
-
_HANDLER_MIGRATOR_PLANS:
|
37
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
48
38
|
|
49
39
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
50
40
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -154,7 +144,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
154
144
|
model_type=cls.HANDLER_TYPE,
|
155
145
|
handler_version=cls.HANDLER_VERSION,
|
156
146
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
157
|
-
options=model_meta_schema.XgboostModelBlobOptions(
|
147
|
+
options=model_meta_schema.XgboostModelBlobOptions(
|
148
|
+
{
|
149
|
+
"xgb_estimator_type": model.__class__.__name__,
|
150
|
+
"enable_categorical": getattr(model, "enable_categorical", False),
|
151
|
+
}
|
152
|
+
),
|
158
153
|
)
|
159
154
|
model_meta.models[name] = base_meta
|
160
155
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -162,11 +157,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
162
157
|
model_meta.env.include_if_absent(
|
163
158
|
[
|
164
159
|
model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
|
165
|
-
],
|
166
|
-
check_local_version=True,
|
167
|
-
)
|
168
|
-
model_meta.env.include_if_absent(
|
169
|
-
[
|
170
160
|
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
171
161
|
],
|
172
162
|
check_local_version=True,
|
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
175
165
|
if enable_explainability:
|
176
166
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
177
167
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
178
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
168
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
179
169
|
|
180
170
|
@classmethod
|
181
171
|
def load_model(
|
@@ -200,6 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
200
190
|
raise ValueError("Type of XGB estimator is illegal.")
|
201
191
|
m = getattr(xgboost, xgb_estimator_type)()
|
202
192
|
m.load_model(os.path.join(model_blob_path, model_blob_filename))
|
193
|
+
m.enable_categorical = model_blob_options.get("enable_categorical", False)
|
203
194
|
|
204
195
|
if kwargs.get("use_gpu", False):
|
205
196
|
assert type(kwargs.get("use_gpu", False)) == bool
|
@@ -227,7 +218,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
227
218
|
def _create_custom_model(
|
228
219
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
229
220
|
model_meta: model_meta_api.ModelMetadata,
|
230
|
-
) ->
|
221
|
+
) -> type[custom_model.CustomModel]:
|
231
222
|
def fn_factory(
|
232
223
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
233
224
|
signature: model_signature.ModelSignature,
|
@@ -235,8 +226,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
235
226
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
236
227
|
@custom_model.inference_api
|
237
228
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
229
|
+
enable_categorical = False
|
230
|
+
for col, d_type in X.dtypes.items():
|
231
|
+
if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
|
232
|
+
continue
|
233
|
+
if not np.issubdtype(d_type, np.number):
|
234
|
+
# categorical columns are converted to numpy's str dtype
|
235
|
+
X[col] = X[col].astype("category")
|
236
|
+
enable_categorical = True
|
238
237
|
if isinstance(raw_model, xgboost.Booster):
|
239
|
-
X = xgboost.DMatrix(X)
|
238
|
+
X = xgboost.DMatrix(X, enable_categorical=enable_categorical)
|
240
239
|
|
241
240
|
res = getattr(raw_model, target_method)(X)
|
242
241
|
|
@@ -261,7 +260,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
261
260
|
return explain_fn
|
262
261
|
return fn
|
263
262
|
|
264
|
-
type_method_dict:
|
263
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
265
264
|
for target_method_name, sig in model_meta.signatures.items():
|
266
265
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
267
266
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import cast
|
2
2
|
|
3
3
|
from typing_extensions import Unpack
|
4
4
|
|
@@ -25,7 +25,7 @@ class ModelBlobMeta:
|
|
25
25
|
self.handler_version = kwargs["handler_version"]
|
26
26
|
self.function_properties = kwargs.get("function_properties", {})
|
27
27
|
|
28
|
-
self.artifacts:
|
28
|
+
self.artifacts: dict[str, str] = {}
|
29
29
|
artifacts = kwargs.get("artifacts", None)
|
30
30
|
if artifacts:
|
31
31
|
self.artifacts = artifacts
|