snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
3
4
|
|
4
5
|
import numpy as np
|
@@ -8,7 +9,11 @@ from typing_extensions import TypeGuard, Unpack
|
|
8
9
|
from snowflake.ml._internal import type_utils
|
9
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
10
11
|
from snowflake.ml.model._packager.model_env import model_env
|
11
|
-
from snowflake.ml.model._packager.model_handlers import
|
12
|
+
from snowflake.ml.model._packager.model_handlers import (
|
13
|
+
_base,
|
14
|
+
_utils as handlers_utils,
|
15
|
+
model_objective_utils,
|
16
|
+
)
|
12
17
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
13
18
|
from snowflake.ml.model._packager.model_meta import (
|
14
19
|
model_blob_meta,
|
@@ -32,22 +37,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
32
37
|
|
33
38
|
MODEL_BLOB_FILE_OR_DIR = "model.bin"
|
34
39
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
35
|
-
|
36
|
-
@classmethod
|
37
|
-
def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
|
38
|
-
import catboost
|
39
|
-
|
40
|
-
if isinstance(model, catboost.CatBoostClassifier):
|
41
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
42
|
-
if num_classes == 2:
|
43
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
44
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
45
|
-
if isinstance(model, catboost.CatBoostRanker):
|
46
|
-
return model_meta_schema.ModelObjective.RANKING
|
47
|
-
if isinstance(model, catboost.CatBoostRegressor):
|
48
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
49
|
-
# TODO: Find out model type from the generic Catboost Model
|
50
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
40
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
51
41
|
|
52
42
|
@classmethod
|
53
43
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
|
@@ -77,6 +67,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
77
67
|
is_sub_model: Optional[bool] = False,
|
78
68
|
**kwargs: Unpack[model_types.CatBoostModelSaveOptions],
|
79
69
|
) -> None:
|
70
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
71
|
+
|
80
72
|
import catboost
|
81
73
|
|
82
74
|
assert isinstance(model, catboost.CatBoost)
|
@@ -105,22 +97,34 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
105
97
|
sample_input_data=sample_input_data,
|
106
98
|
get_prediction_fn=get_prediction,
|
107
99
|
)
|
108
|
-
|
109
|
-
model_meta.
|
110
|
-
if
|
111
|
-
|
112
|
-
if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
|
113
|
-
output_type = model_signature.DataType.STRING
|
100
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
101
|
+
model_meta.task = model_task_and_output.task
|
102
|
+
if enable_explainability:
|
103
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
114
104
|
model_meta = handlers_utils.add_explain_method_signature(
|
115
105
|
model_meta=model_meta,
|
116
106
|
explain_method="explain",
|
117
|
-
target_method=
|
118
|
-
output_return_type=output_type,
|
107
|
+
target_method=explain_target_method,
|
108
|
+
output_return_type=model_task_and_output.output_type,
|
119
109
|
)
|
120
110
|
model_meta.function_properties = {
|
121
111
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
122
112
|
}
|
123
113
|
|
114
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
115
|
+
sample_input_data, model_meta, explain_target_method
|
116
|
+
)
|
117
|
+
if background_data is not None:
|
118
|
+
handlers_utils.save_background_data(
|
119
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
120
|
+
)
|
121
|
+
else:
|
122
|
+
warnings.warn(
|
123
|
+
"sample_input_data should be provided for better explainability results",
|
124
|
+
category=UserWarning,
|
125
|
+
stacklevel=1,
|
126
|
+
)
|
127
|
+
|
124
128
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
125
129
|
os.makedirs(model_blob_path, exist_ok=True)
|
126
130
|
model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
@@ -143,11 +147,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
143
147
|
],
|
144
148
|
check_local_version=True,
|
145
149
|
)
|
146
|
-
if
|
147
|
-
model_meta.env.include_if_absent(
|
148
|
-
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
149
|
-
check_local_version=True,
|
150
|
-
)
|
150
|
+
if enable_explainability:
|
151
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
151
152
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
152
153
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
153
154
|
|
@@ -242,10 +242,10 @@ class HuggingFacePipelineHandler(
|
|
242
242
|
task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
243
243
|
)
|
244
244
|
if framework is None or framework == "pt":
|
245
|
-
# Since we set default cuda version to be 11.
|
246
|
-
# Pytorch version that works with CUDA 11.
|
245
|
+
# Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
|
246
|
+
# Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
|
247
247
|
# users are not required to install pytorch locally if they are using the wrapper.
|
248
|
-
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch
|
248
|
+
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
249
249
|
elif framework == "tf":
|
250
250
|
pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
251
251
|
model_meta.env.include_if_absent(
|
@@ -369,7 +369,9 @@ class HuggingFacePipelineHandler(
|
|
369
369
|
else:
|
370
370
|
# For others, we could offer the whole dataframe as a list.
|
371
371
|
# Some of them may need some conversion
|
372
|
-
if
|
372
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
373
|
+
raw_model, transformers.ConversationalPipeline
|
374
|
+
):
|
373
375
|
input_data = [
|
374
376
|
transformers.Conversation(
|
375
377
|
text=conv_data["user_inputs"][0],
|
@@ -391,27 +393,33 @@ class HuggingFacePipelineHandler(
|
|
391
393
|
# Making it not aligned with the auto-inferred signature.
|
392
394
|
# If the output is a dict, we could blindly create a list containing that.
|
393
395
|
# Otherwise, creating pandas DataFrame won't succeed.
|
394
|
-
if
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
396
|
+
if (
|
397
|
+
(hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation))
|
398
|
+
or isinstance(temp_res, dict)
|
399
|
+
or (
|
400
|
+
# For some pipeline that is expected to generate a list of dict per input
|
401
|
+
# When it omit outer list, it becomes list of dict instead of list of list of dict.
|
402
|
+
# We need to distinguish them from those pipelines that designed to output a dict per input
|
403
|
+
# So we need to check the pipeline type.
|
404
|
+
isinstance(
|
405
|
+
raw_model,
|
406
|
+
(
|
407
|
+
transformers.FillMaskPipeline,
|
408
|
+
transformers.QuestionAnsweringPipeline,
|
409
|
+
),
|
410
|
+
)
|
411
|
+
and X.shape[0] == 1
|
412
|
+
and isinstance(temp_res[0], dict)
|
405
413
|
)
|
406
|
-
and X.shape[0] == 1
|
407
|
-
and isinstance(temp_res[0], dict)
|
408
414
|
):
|
409
415
|
temp_res = [temp_res]
|
410
416
|
|
411
417
|
if len(temp_res) == 0:
|
412
418
|
return pd.DataFrame()
|
413
419
|
|
414
|
-
if
|
420
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
421
|
+
raw_model, transformers.ConversationalPipeline
|
422
|
+
):
|
415
423
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
416
424
|
|
417
425
|
# To concat those who outputs a list with one input.
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import (
|
3
4
|
TYPE_CHECKING,
|
4
5
|
Any,
|
@@ -19,7 +20,11 @@ from typing_extensions import TypeGuard, Unpack
|
|
19
20
|
from snowflake.ml._internal import type_utils
|
20
21
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
21
22
|
from snowflake.ml.model._packager.model_env import model_env
|
22
|
-
from snowflake.ml.model._packager.model_handlers import
|
23
|
+
from snowflake.ml.model._packager.model_handlers import (
|
24
|
+
_base,
|
25
|
+
_utils as handlers_utils,
|
26
|
+
model_objective_utils,
|
27
|
+
)
|
23
28
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
24
29
|
from snowflake.ml.model._packager.model_meta import (
|
25
30
|
model_blob_meta,
|
@@ -43,47 +48,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
43
48
|
|
44
49
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
45
50
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
46
|
-
|
47
|
-
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
48
|
-
_RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
|
49
|
-
_REGRESSION_OBJECTIVES = [
|
50
|
-
"regression",
|
51
|
-
"regression_l1",
|
52
|
-
"huber",
|
53
|
-
"fair",
|
54
|
-
"poisson",
|
55
|
-
"quantile",
|
56
|
-
"tweedie",
|
57
|
-
"mape",
|
58
|
-
"gamma",
|
59
|
-
]
|
60
|
-
|
61
|
-
@classmethod
|
62
|
-
def get_model_objective(
|
63
|
-
cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
|
64
|
-
) -> model_meta_schema.ModelObjective:
|
65
|
-
import lightgbm
|
66
|
-
|
67
|
-
# does not account for cross-entropy and custom
|
68
|
-
if isinstance(model, lightgbm.LGBMClassifier):
|
69
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
70
|
-
if num_classes == 2:
|
71
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
72
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
73
|
-
if isinstance(model, lightgbm.LGBMRanker):
|
74
|
-
return model_meta_schema.ModelObjective.RANKING
|
75
|
-
if isinstance(model, lightgbm.LGBMRegressor):
|
76
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
77
|
-
model_objective = model.params["objective"]
|
78
|
-
if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
|
79
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
80
|
-
if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
|
81
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
82
|
-
if model_objective in cls._RANKING_OBJECTIVES:
|
83
|
-
return model_meta_schema.ModelObjective.RANKING
|
84
|
-
if model_objective in cls._REGRESSION_OBJECTIVES:
|
85
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
86
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
51
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
87
52
|
|
88
53
|
@classmethod
|
89
54
|
def can_handle(
|
@@ -118,6 +83,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
118
83
|
is_sub_model: Optional[bool] = False,
|
119
84
|
**kwargs: Unpack[model_types.LGBMModelSaveOptions],
|
120
85
|
) -> None:
|
86
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
87
|
+
|
121
88
|
import lightgbm
|
122
89
|
|
123
90
|
assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
|
@@ -146,25 +113,34 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
146
113
|
sample_input_data=sample_input_data,
|
147
114
|
get_prediction_fn=get_prediction,
|
148
115
|
)
|
149
|
-
|
150
|
-
model_meta.
|
151
|
-
if
|
152
|
-
|
153
|
-
if model_objective in [
|
154
|
-
model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
|
155
|
-
model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
|
156
|
-
]:
|
157
|
-
output_type = model_signature.DataType.STRING
|
116
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
117
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
118
|
+
if enable_explainability:
|
119
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
158
120
|
model_meta = handlers_utils.add_explain_method_signature(
|
159
121
|
model_meta=model_meta,
|
160
122
|
explain_method="explain",
|
161
|
-
target_method=
|
162
|
-
output_return_type=output_type,
|
123
|
+
target_method=explain_target_method,
|
124
|
+
output_return_type=model_task_and_output.output_type,
|
163
125
|
)
|
164
126
|
model_meta.function_properties = {
|
165
127
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
166
128
|
}
|
167
129
|
|
130
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
131
|
+
sample_input_data, model_meta, explain_target_method
|
132
|
+
)
|
133
|
+
if background_data is not None:
|
134
|
+
handlers_utils.save_background_data(
|
135
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
136
|
+
)
|
137
|
+
else:
|
138
|
+
warnings.warn(
|
139
|
+
"sample_input_data should be provided for better explainability results",
|
140
|
+
category=UserWarning,
|
141
|
+
stacklevel=1,
|
142
|
+
)
|
143
|
+
|
168
144
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
169
145
|
os.makedirs(model_blob_path, exist_ok=True)
|
170
146
|
|
@@ -189,11 +165,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
189
165
|
],
|
190
166
|
check_local_version=True,
|
191
167
|
)
|
192
|
-
if
|
193
|
-
model_meta.env.include_if_absent(
|
194
|
-
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
195
|
-
check_local_version=True,
|
196
|
-
)
|
168
|
+
if enable_explainability:
|
169
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
197
170
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
198
171
|
|
199
172
|
return None
|
@@ -168,11 +168,6 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
168
168
|
) -> "mlflow.pyfunc.PyFuncModel":
|
169
169
|
import mlflow
|
170
170
|
|
171
|
-
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
172
|
-
# We need to redirect the mlruns folder to a writable location in the sandbox.
|
173
|
-
tmpdir = tempfile.TemporaryDirectory(dir="/tmp")
|
174
|
-
mlflow.set_tracking_uri(f"file://{tmpdir}")
|
175
|
-
|
176
171
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
177
172
|
model_blobs_metadata = model_meta.models
|
178
173
|
model_blob_metadata = model_blobs_metadata[name]
|
@@ -183,6 +178,9 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
183
178
|
model_artifact_path = model_blob_options["artifact_path"]
|
184
179
|
model_blob_filename = model_blob_metadata.path
|
185
180
|
|
181
|
+
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
182
|
+
return mlflow.pyfunc.load_model(os.path.join(model_blob_path, model_blob_filename, model_artifact_path))
|
183
|
+
|
186
184
|
# This is to make sure the loaded model can be saved again.
|
187
185
|
with mlflow.start_run() as run:
|
188
186
|
mlflow.log_artifacts(
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
4
|
+
|
5
|
+
from snowflake.ml._internal import type_utils
|
6
|
+
from snowflake.ml.model import model_signature, type_hints
|
7
|
+
from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
import catboost
|
11
|
+
import lightgbm
|
12
|
+
import sklearn
|
13
|
+
import sklearn.pipeline
|
14
|
+
import xgboost
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
|
18
|
+
class ModelTaskAndOutputType:
|
19
|
+
task: type_hints.Task
|
20
|
+
output_type: model_signature.DataType
|
21
|
+
|
22
|
+
|
23
|
+
def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]) -> type_hints.Task:
|
24
|
+
from sklearn.base import is_classifier, is_regressor
|
25
|
+
|
26
|
+
if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
|
27
|
+
return type_hints.Task.UNKNOWN
|
28
|
+
if is_regressor(model):
|
29
|
+
return type_hints.Task.TABULAR_REGRESSION
|
30
|
+
if is_classifier(model):
|
31
|
+
classes_list = getattr(model, "classes_", [])
|
32
|
+
num_classes = getattr(model, "n_classes_", None) or len(classes_list)
|
33
|
+
if isinstance(num_classes, int):
|
34
|
+
if num_classes > 2:
|
35
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
36
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
37
|
+
return type_hints.Task.UNKNOWN
|
38
|
+
return type_hints.Task.UNKNOWN
|
39
|
+
|
40
|
+
|
41
|
+
def get_model_task_catboost(model: "catboost.CatBoost") -> type_hints.Task:
|
42
|
+
loss_function = None
|
43
|
+
if type_utils.LazyType("catboost.CatBoost").isinstance(model):
|
44
|
+
loss_function = model.get_all_params()["loss_function"] # type: ignore[attr-defined]
|
45
|
+
|
46
|
+
if (type_utils.LazyType("catboost.CatBoostClassifier").isinstance(model)) or model._is_classification_objective(
|
47
|
+
loss_function
|
48
|
+
):
|
49
|
+
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
50
|
+
if num_classes == 0:
|
51
|
+
return type_hints.Task.UNKNOWN
|
52
|
+
if num_classes <= 2:
|
53
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
54
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
55
|
+
if (type_utils.LazyType("catboost.CatBoostRanker").isinstance(model)) or model._is_ranking_objective(loss_function):
|
56
|
+
return type_hints.Task.TABULAR_RANKING
|
57
|
+
if (type_utils.LazyType("catboost.CatBoostRegressor").isinstance(model)) or model._is_regression_objective(
|
58
|
+
loss_function
|
59
|
+
):
|
60
|
+
return type_hints.Task.TABULAR_REGRESSION
|
61
|
+
|
62
|
+
return type_hints.Task.UNKNOWN
|
63
|
+
|
64
|
+
|
65
|
+
def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.Task:
|
66
|
+
|
67
|
+
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
68
|
+
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
69
|
+
_RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
|
70
|
+
_REGRESSION_OBJECTIVES = [
|
71
|
+
"regression",
|
72
|
+
"regression_l1",
|
73
|
+
"huber",
|
74
|
+
"fair",
|
75
|
+
"poisson",
|
76
|
+
"quantile",
|
77
|
+
"tweedie",
|
78
|
+
"mape",
|
79
|
+
"gamma",
|
80
|
+
]
|
81
|
+
|
82
|
+
# does not account for cross-entropy and custom
|
83
|
+
model_task = ""
|
84
|
+
if type_utils.LazyType("lightgbm.Booster").isinstance(model):
|
85
|
+
model_task = model.params["objective"] # type: ignore[attr-defined]
|
86
|
+
elif hasattr(model, "objective_"):
|
87
|
+
model_task = model.objective_
|
88
|
+
if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
|
89
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
90
|
+
if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
|
91
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
92
|
+
if model_task in _RANKING_OBJECTIVES:
|
93
|
+
return type_hints.Task.TABULAR_RANKING
|
94
|
+
if model_task in _REGRESSION_OBJECTIVES:
|
95
|
+
return type_hints.Task.TABULAR_REGRESSION
|
96
|
+
return type_hints.Task.UNKNOWN
|
97
|
+
|
98
|
+
|
99
|
+
def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task:
|
100
|
+
|
101
|
+
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
102
|
+
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
103
|
+
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
104
|
+
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
105
|
+
|
106
|
+
model_task = ""
|
107
|
+
if type_utils.LazyType("xgboost.Booster").isinstance(model):
|
108
|
+
model_params = json.loads(model.save_config()) # type: ignore[attr-defined]
|
109
|
+
model_task = model_params.get("learner", {}).get("objective", "")
|
110
|
+
else:
|
111
|
+
if hasattr(model, "get_params"):
|
112
|
+
model_task = model.get_params().get("objective", "")
|
113
|
+
|
114
|
+
if isinstance(model_task, dict):
|
115
|
+
model_task = model_task.get("name", "")
|
116
|
+
for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
117
|
+
if classification_objective in model_task:
|
118
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
119
|
+
for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
120
|
+
if classification_objective in model_task:
|
121
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
122
|
+
for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
|
123
|
+
if ranking_objective in model_task:
|
124
|
+
return type_hints.Task.TABULAR_RANKING
|
125
|
+
for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
|
126
|
+
if regression_objective in model_task:
|
127
|
+
return type_hints.Task.TABULAR_REGRESSION
|
128
|
+
return type_hints.Task.UNKNOWN
|
129
|
+
|
130
|
+
|
131
|
+
def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
|
132
|
+
if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
|
133
|
+
model
|
134
|
+
):
|
135
|
+
task = get_model_task_xgb(model)
|
136
|
+
output_type = model_signature.DataType.DOUBLE
|
137
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
138
|
+
output_type = model_signature.DataType.STRING
|
139
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
140
|
+
|
141
|
+
if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
|
142
|
+
"lightgbm.LGBMModel"
|
143
|
+
).isinstance(model):
|
144
|
+
task = get_model_task_lightgbm(model)
|
145
|
+
output_type = model_signature.DataType.DOUBLE
|
146
|
+
if task in [
|
147
|
+
type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
|
148
|
+
type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
|
149
|
+
]:
|
150
|
+
output_type = model_signature.DataType.STRING
|
151
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
152
|
+
|
153
|
+
if type_utils.LazyType("catboost.CatBoost").isinstance(model):
|
154
|
+
task = get_model_task_catboost(model)
|
155
|
+
output_type = model_signature.DataType.DOUBLE
|
156
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
157
|
+
output_type = model_signature.DataType.STRING
|
158
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
159
|
+
|
160
|
+
if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
|
161
|
+
"sklearn.pipeline.Pipeline"
|
162
|
+
).isinstance(model):
|
163
|
+
task = get_task_skl(model)
|
164
|
+
output_type = model_signature.DataType.DOUBLE
|
165
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
166
|
+
output_type = model_signature.DataType.STRING
|
167
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
168
|
+
|
169
|
+
raise ValueError(f"Model type {type(model)} is not supported")
|
@@ -2,7 +2,6 @@ import logging
|
|
2
2
|
import os
|
3
3
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
4
|
|
5
|
-
import cloudpickle
|
6
5
|
import pandas as pd
|
7
6
|
from typing_extensions import TypeGuard, Unpack
|
8
7
|
|
@@ -120,9 +119,21 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
120
119
|
model_meta.env.include_if_absent(
|
121
120
|
[
|
122
121
|
model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
|
122
|
+
model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
|
123
|
+
model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
|
123
124
|
],
|
124
125
|
check_local_version=True,
|
125
126
|
)
|
127
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
128
|
+
|
129
|
+
@staticmethod
|
130
|
+
def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
|
131
|
+
if kwargs.get("device", None) is not None:
|
132
|
+
return kwargs["device"]
|
133
|
+
elif kwargs.get("use_gpu", False):
|
134
|
+
return "cuda"
|
135
|
+
|
136
|
+
return None
|
126
137
|
|
127
138
|
@classmethod
|
128
139
|
def load_model(
|
@@ -144,13 +155,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
144
155
|
model_blob_filename = model_blob_metadata.path
|
145
156
|
model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
|
146
157
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
assert os.path.isfile(model_blob_file_or_dir_path) # if the saved model is a file
|
151
|
-
with open(model_blob_file_or_dir_path, "rb") as f:
|
152
|
-
model = cloudpickle.load(f)
|
153
|
-
assert isinstance(model, sentence_transformers.SentenceTransformer)
|
158
|
+
model = sentence_transformers.SentenceTransformer(
|
159
|
+
model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
|
160
|
+
)
|
154
161
|
return model
|
155
162
|
|
156
163
|
@classmethod
|