snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,206 @@
|
|
1
|
+
import os
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
from typing_extensions import TypeGuard, Unpack
|
7
|
+
|
8
|
+
from snowflake.ml._internal import type_utils
|
9
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
10
|
+
from snowflake.ml.model._packager.model_env import model_env
|
11
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
12
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
13
|
+
from snowflake.ml.model._packager.model_meta import (
|
14
|
+
model_blob_meta,
|
15
|
+
model_meta as model_meta_api,
|
16
|
+
model_meta_schema,
|
17
|
+
)
|
18
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
19
|
+
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
import catboost
|
22
|
+
|
23
|
+
|
24
|
+
@final
class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
    """Handler for CatBoost based model.

    The model is written with CatBoost's own ``save_model`` and the estimator's class name is stored in
    the blob metadata, so that ``load_model`` can re-instantiate the same class before loading the file.
    """

    HANDLER_TYPE = "catboost"
    HANDLER_VERSION = "2024-03-21"
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # File name of the serialized model inside the per-model blob directory.
    # (The attribute spelling is kept as-is; it is part of the class's visible interface.)
    MODELE_BLOB_FILE_OR_DIR = "model.bin"
    # Methods probed by `can_handle` and passed as defaults when resolving target methods in `save_model`.
    DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

    @classmethod
    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
        """Tell whether this handler applies to `model`.

        Args:
            model: The candidate model object.

        Returns:
            A truthy value when `model` is a `catboost.CatBoost` instance (checked lazily, without importing
            catboost here) and at least one of `DEFAULT_TARGET_METHODS` is a callable attribute of it.
        """
        # `getattr(..., None)` already covers the missing-attribute case, so no separate `hasattr` is needed.
        return type_utils.LazyType("catboost.CatBoost").isinstance(model) and any(
            callable(getattr(model, method, None)) for method in cls.DEFAULT_TARGET_METHODS
        )

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> "catboost.CatBoost":
        """Narrow `model` to `catboost.CatBoost`.

        Args:
            model: The model object to narrow.

        Returns:
            The same object, unchanged.
        """
        import catboost

        assert isinstance(model, catboost.CatBoost)

        return model

    @classmethod
    def save_model(
        cls,
        name: str,
        model: "catboost.CatBoost",
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.CatBoostModelSaveOptions],
    ) -> None:
        """Serialize `model` under `model_blobs_dir_path/name` and register it in `model_meta`.

        Args:
            name: Name of the model; used as the blob sub-directory and as the key in `model_meta.models`.
            model: The CatBoost estimator to save.
            model_meta: Metadata object that is updated with the blob entry, minimum library version,
                dependencies and CUDA version.
            model_blobs_dir_path: Directory under which the model's blob directory is created.
            sample_input_data: Optional sample input forwarded to signature validation.
            is_sub_model: When falsy, target methods are resolved and signatures are validated.
            **kwargs: Save options; `target_methods` and `cuda_version` are read here.
        """
        import catboost

        assert isinstance(model, catboost.CatBoost)

        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # Anything that is not already a DataFrame/ndarray is converted to a DataFrame first.
                if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
                    sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
                target_method = getattr(model, target_method_name, None)
                assert callable(target_method)
                predictions_df = target_method(sample_input_data)
                return predictions_df

            model_meta = handlers_utils.validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=target_methods,
                sample_input_data=sample_input_data,
                get_prediction_fn=get_prediction,
            )

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)
        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)

        model.save_model(model_save_path)

        # The estimator class name is recorded so `load_model` can rebuild the same class.
        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
            options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
            ],
            check_local_version=True,
        )
        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> "catboost.CatBoost":
        """Load the CatBoost estimator previously written by `save_model`.

        Args:
            name: Name of the model; key into `model_meta.models` and blob sub-directory name.
            model_meta: Metadata holding the blob entry for `name`.
            model_blobs_dir_path: Directory containing the model's blob directory.
            **kwargs: Load options; `use_gpu` is read here.

        Returns:
            An instance of the recorded CatBoost estimator class with the saved model loaded into it.

        Raises:
            ValueError: If the blob options lack `catboost_estimator_type`, or the recorded type is not
                an attribute of the `catboost` module.
        """
        import catboost

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_path = os.path.join(model_blob_path, model_blob_filename)

        model_blob_options = cast(model_meta_schema.CatBoostModelBlobOptions, model_blob_metadata.options)
        if "catboost_estimator_type" not in model_blob_options:
            raise ValueError("Missing field `catboost_estimator_type` in model blob metadata for type `catboost`")

        catboost_estimator_type = model_blob_options["catboost_estimator_type"]
        if not hasattr(catboost, catboost_estimator_type):
            raise ValueError("Type of CatBoost estimator is not supported.")

        assert os.path.isfile(model_blob_file_path)  # saved model is a file
        estimator_cls = getattr(catboost, catboost_estimator_type)
        model = estimator_cls()
        model.load_model(model_blob_file_path)
        assert isinstance(model, estimator_cls)

        use_gpu = kwargs.get("use_gpu", False)
        if use_gpu:
            assert isinstance(use_gpu, bool)
            # NOTE(review): this writes `task_type` straight into the instance `__dict__`; whether CatBoost
            # reads that attribute at prediction time is not visible from here -- confirm it has an effect.
            gpu_params = {"task_type": "GPU"}
            model.__dict__.update(gpu_params)

        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: "catboost.CatBoost",
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap `raw_model` in a dynamically created `CustomModel` subclass.

        One inference method is generated per entry in `model_meta.signatures`; each calls the method of
        the same name on `raw_model` and renames the output columns according to the signature.

        Args:
            raw_model: The loaded CatBoost estimator.
            model_meta: Metadata whose `signatures` mapping drives which methods are generated.
            **kwargs: Load options; not read in this method.

        Returns:
            An instance of the generated `_CatBoostModel` class, built with an empty `ModelContext`.
        """
        # Only referenced through string annotations below; kept so the import behaviour is unchanged.
        import catboost

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: "catboost.CatBoost",
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def fn_factory(
                raw_model: "catboost.CatBoost",
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:

                    res = getattr(raw_model, target_method)(X)

                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
                        # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
                        # return a list of ndarrays. We need to deal with them separately.
                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
                    else:
                        df = pd.DataFrame(res)

                    return model_signature_utils.rename_pandas_df(df, signature.outputs)

                return fn

            # Class namespace: the raw model plus one generated inference method per signature.
            type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
            for target_method_name, sig in model_meta.signatures.items():
                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)

            _CatBoostModel = type(
                "_CatBoostModel",
                (custom_model.CustomModel,),
                type_method_dict,
            )

            return _CatBoostModel

        _CatBoostModel = _create_custom_model(raw_model, model_meta)
        catboost_model = _CatBoostModel(custom_model.ModelContext())

        return catboost_model
|
@@ -0,0 +1,218 @@
|
|
1
|
+
import os
|
2
|
+
from typing import (
|
3
|
+
TYPE_CHECKING,
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
Dict,
|
7
|
+
Optional,
|
8
|
+
Type,
|
9
|
+
Union,
|
10
|
+
cast,
|
11
|
+
final,
|
12
|
+
)
|
13
|
+
|
14
|
+
import cloudpickle
|
15
|
+
import numpy as np
|
16
|
+
import pandas as pd
|
17
|
+
from typing_extensions import TypeGuard, Unpack
|
18
|
+
|
19
|
+
from snowflake.ml._internal import type_utils
|
20
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
21
|
+
from snowflake.ml.model._packager.model_env import model_env
|
22
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
23
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
24
|
+
from snowflake.ml.model._packager.model_meta import (
|
25
|
+
model_blob_meta,
|
26
|
+
model_meta as model_meta_api,
|
27
|
+
model_meta_schema,
|
28
|
+
)
|
29
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
30
|
+
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
import lightgbm
|
33
|
+
|
34
|
+
|
35
|
+
@final
class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]):
    """Handler for LightGBM based model.

    The model is serialized with cloudpickle and the estimator's class name is stored in the blob
    metadata, so that `load_model` can check the unpickled object against the recorded class.
    """

    HANDLER_TYPE = "lightgbm"
    HANDLER_VERSION = "2024-03-19"
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # File name of the pickled model inside the per-model blob directory.
    # (The attribute spelling is kept as-is; it is part of the class's visible interface.)
    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
    # Methods probed by `can_handle` and passed as defaults when resolving target methods in `save_model`.
    DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

    @classmethod
    def can_handle(
        cls, model: model_types.SupportedModelType
    ) -> TypeGuard[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]:
        """Tell whether this handler applies to `model`.

        Args:
            model: The candidate model object.

        Returns:
            A truthy value when `model` is a `lightgbm.Booster` or `lightgbm.LGBMModel` instance (checked
            lazily, without importing lightgbm here) and at least one of `DEFAULT_TARGET_METHODS` is a
            callable attribute of it.
        """
        # `getattr(..., None)` already covers the missing-attribute case, so no separate `hasattr` is needed.
        return (
            type_utils.LazyType("lightgbm.Booster").isinstance(model)
            or type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
        ) and any(callable(getattr(model, method, None)) for method in cls.DEFAULT_TARGET_METHODS)

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Narrow `model` to a LightGBM model type.

        Args:
            model: The model object to narrow.

        Returns:
            The same object, unchanged.
        """
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        return model

    @classmethod
    def save_model(
        cls,
        name: str,
        model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.LGBMModelSaveOptions],
    ) -> None:
        """Pickle `model` under `model_blobs_dir_path/name` and register it in `model_meta`.

        Args:
            name: Name of the model; used as the blob sub-directory and as the key in `model_meta.models`.
            model: The LightGBM model to save.
            model_meta: Metadata object that is updated with the blob entry, minimum library version
                and dependencies.
            model_blobs_dir_path: Directory under which the model's blob directory is created.
            sample_input_data: Optional sample input forwarded to signature validation.
            is_sub_model: When falsy, target methods are resolved and signatures are validated.
            **kwargs: Save options; `target_methods` is read here.
        """
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # Anything that is not already a DataFrame/ndarray is converted to a DataFrame first.
                if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
                    sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
                target_method = getattr(model, target_method_name, None)
                assert callable(target_method)
                predictions_df = target_method(sample_input_data)
                return predictions_df

            model_meta = handlers_utils.validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=target_methods,
                sample_input_data=sample_input_data,
                get_prediction_fn=get_prediction,
            )

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)

        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
        with open(model_save_path, "wb") as f:
            cloudpickle.dump(model, f)

        # The estimator class name is recorded so `load_model` can verify what it unpickles.
        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
            options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
            ],
            check_local_version=True,
        )

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Load the LightGBM model previously written by `save_model`.

        Args:
            name: Name of the model; key into `model_meta.models` and blob sub-directory name.
            model_meta: Metadata holding the blob entry for `name`.
            model_blobs_dir_path: Directory containing the model's blob directory.
            **kwargs: Load options; not read in this method.

        Returns:
            The unpickled model, asserted to be an instance of the recorded LightGBM class.

        Raises:
            ValueError: If the blob options lack `lightgbm_estimator_type`, or the recorded type is not
                an attribute of the `lightgbm` module.
        """
        import lightgbm

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_path = os.path.join(model_blob_path, model_blob_filename)

        model_blob_options = cast(model_meta_schema.LightGBMModelBlobOptions, model_blob_metadata.options)
        if "lightgbm_estimator_type" not in model_blob_options:
            raise ValueError("Missing field `lightgbm_estimator_type` in model blob metadata for type `lightgbm`")

        lightgbm_estimator_type = model_blob_options["lightgbm_estimator_type"]
        if not hasattr(lightgbm, lightgbm_estimator_type):
            raise ValueError("Type of LightGBM estimator is not supported.")

        assert os.path.isfile(model_blob_file_path)  # saved model is a file
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the file; the blob must come
        # from a trusted source. The type check below runs only after the payload has been loaded.
        with open(model_blob_file_path, "rb") as f:
            model = cloudpickle.load(f)
        assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))

        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap `raw_model` in a dynamically created `CustomModel` subclass.

        One inference method is generated per entry in `model_meta.signatures`; each calls the method of
        the same name on `raw_model` and renames the output columns according to the signature.

        Args:
            raw_model: The loaded LightGBM model.
            model_meta: Metadata whose `signatures` mapping drives which methods are generated.
            **kwargs: Load options; not read in this method.

        Returns:
            An instance of the generated `_LightGBMModel` class, built with an empty `ModelContext`.
        """
        # Only referenced through string annotations below; kept so the import behaviour is unchanged.
        import lightgbm

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def fn_factory(
                raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:

                    res = getattr(raw_model, target_method)(X)

                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
                        # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
                        # return a list of ndarrays. We need to deal with them separately.
                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
                    else:
                        df = pd.DataFrame(res)

                    return model_signature_utils.rename_pandas_df(df, signature.outputs)

                return fn

            # Class namespace: the raw model plus one generated inference method per signature.
            type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
            for target_method_name, sig in model_meta.signatures.items():
                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)

            _LightGBMModel = type(
                "_LightGBMModel",
                (custom_model.CustomModel,),
                type_method_dict,
            )

            return _LightGBMModel

        _LightGBMModel = _create_custom_model(raw_model, model_meta)
        lightgbm_model = _LightGBMModel(custom_model.ModelContext())

        return lightgbm_model
|
@@ -47,6 +47,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
47
47
|
or type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model)
|
48
48
|
)
|
49
49
|
and (not type_utils.LazyType("xgboost.XGBModel").isinstance(model)) # XGBModel is actually a BaseEstimator
|
50
|
+
and (
|
51
|
+
not type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
|
52
|
+
) # LGBMModel is actually a BaseEstimator
|
50
53
|
and any(
|
51
54
|
(hasattr(model, method) and callable(getattr(model, method, None)))
|
52
55
|
for method in cls.DEFAULT_TARGET_METHODS
|
@@ -23,6 +23,7 @@ from snowflake.ml.model._packager.model_meta import (
|
|
23
23
|
model_meta_schema,
|
24
24
|
)
|
25
25
|
from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
|
26
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
26
27
|
|
27
28
|
MODEL_METADATA_FILE = "model.yaml"
|
28
29
|
MODEL_CODE_DIR = "code"
|
@@ -115,7 +116,6 @@ def create_model_metadata(
|
|
115
116
|
python_version=python_version,
|
116
117
|
embed_local_ml_library=embed_local_ml_library,
|
117
118
|
legacy_save=legacy_save,
|
118
|
-
relax_version=relax_version,
|
119
119
|
)
|
120
120
|
|
121
121
|
if embed_local_ml_library:
|
@@ -156,6 +156,8 @@ def create_model_metadata(
|
|
156
156
|
cloudpickle.register_pickle_by_value(mod)
|
157
157
|
imported_modules.append(mod)
|
158
158
|
yield model_meta
|
159
|
+
if relax_version:
|
160
|
+
model_meta.env.relax_version()
|
159
161
|
model_meta.save(model_dir_path)
|
160
162
|
finally:
|
161
163
|
for mod in imported_modules:
|
@@ -169,7 +171,6 @@ def _create_env_for_model_metadata(
|
|
169
171
|
python_version: Optional[str] = None,
|
170
172
|
embed_local_ml_library: bool = False,
|
171
173
|
legacy_save: bool = False,
|
172
|
-
relax_version: bool = False,
|
173
174
|
) -> model_env.ModelEnv:
|
174
175
|
env = model_env.ModelEnv()
|
175
176
|
|
@@ -197,10 +198,6 @@ def _create_env_for_model_metadata(
|
|
197
198
|
],
|
198
199
|
check_local_version=True,
|
199
200
|
)
|
200
|
-
|
201
|
-
if relax_version:
|
202
|
-
env.relax_version()
|
203
|
-
|
204
201
|
return env
|
205
202
|
|
206
203
|
|
@@ -237,6 +234,7 @@ class ModelMetadata:
|
|
237
234
|
name: str,
|
238
235
|
env: model_env.ModelEnv,
|
239
236
|
model_type: model_types.SupportedModelHandlerType,
|
237
|
+
runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
|
240
238
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
241
239
|
metadata: Optional[Dict[str, str]] = None,
|
242
240
|
creation_timestamp: Optional[str] = None,
|
@@ -262,6 +260,8 @@ class ModelMetadata:
|
|
262
260
|
if models:
|
263
261
|
self.models = models
|
264
262
|
|
263
|
+
self._runtimes = runtimes
|
264
|
+
|
265
265
|
self.original_metadata_version = original_metadata_version
|
266
266
|
|
267
267
|
@property
|
@@ -273,6 +273,19 @@ class ModelMetadata:
|
|
273
273
|
parsed_min_snowpark_ml_version = version.parse(min_snowpark_ml_version)
|
274
274
|
self._min_snowpark_ml_version = max(self._min_snowpark_ml_version, parsed_min_snowpark_ml_version)
|
275
275
|
|
276
|
+
@property
def runtimes(self) -> Dict[str, model_runtime.ModelRuntime]:
    """Return the runtimes of this model, falling back to defaults derived from ``env``.

    The runtimes passed to the constructor are returned unchanged when they are non-empty
    and contain a ``"cpu"`` entry. In every other case they are ignored and a fresh
    mapping is built on each access: always a ``"cpu"`` runtime, plus a ``"gpu"`` runtime
    when ``self.env.cuda_version`` is truthy. The fallback is not stored back on the
    instance.
    """
    stored = self._runtimes
    if stored and "cpu" in stored:
        return stored

    fallback = {"cpu": model_runtime.ModelRuntime("cpu", self.env)}
    if self.env.cuda_version:
        fallback["gpu"] = model_runtime.ModelRuntime(
            "gpu", self.env, is_gpu=True, server_availability_source="conda"
        )
    return fallback
|
288
|
+
|
276
289
|
def save(self, model_dir_path: str) -> None:
|
277
290
|
"""Save the model metadata
|
278
291
|
|
@@ -291,6 +304,10 @@ class ModelMetadata:
|
|
291
304
|
{
|
292
305
|
"creation_timestamp": self.creation_timestamp,
|
293
306
|
"env": self.env.save_as_dict(pathlib.Path(model_dir_path)),
|
307
|
+
"runtimes": {
|
308
|
+
runtime_name: runtime.save(pathlib.Path(model_dir_path))
|
309
|
+
for runtime_name, runtime in self.runtimes.items()
|
310
|
+
},
|
294
311
|
"metadata": self.metadata,
|
295
312
|
"model_type": self.model_type,
|
296
313
|
"models": {model_name: blob.to_dict() for model_name, blob in self.models.items()},
|
@@ -302,11 +319,8 @@ class ModelMetadata:
|
|
302
319
|
)
|
303
320
|
|
304
321
|
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
305
|
-
yaml.
|
306
|
-
|
307
|
-
stream=out,
|
308
|
-
default_flow_style=False,
|
309
|
-
)
|
322
|
+
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
323
|
+
yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
|
310
324
|
|
311
325
|
@staticmethod
|
312
326
|
def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadataDict:
|
@@ -330,6 +344,7 @@ class ModelMetadata:
|
|
330
344
|
return model_meta_schema.ModelMetadataDict(
|
331
345
|
creation_timestamp=loaded_meta["creation_timestamp"],
|
332
346
|
env=loaded_meta["env"],
|
347
|
+
runtimes=loaded_meta.get("runtimes", None),
|
333
348
|
metadata=loaded_meta.get("metadata", None),
|
334
349
|
model_type=loaded_meta["model_type"],
|
335
350
|
models=loaded_meta["models"],
|
@@ -363,10 +378,21 @@ class ModelMetadata:
|
|
363
378
|
models = {name: model_blob_meta.ModelBlobMeta(**blob_meta) for name, blob_meta in model_dict["models"].items()}
|
364
379
|
env = model_env.ModelEnv()
|
365
380
|
env.load_from_dict(pathlib.Path(model_dir_path), model_dict["env"])
|
381
|
+
|
382
|
+
runtimes: Optional[Dict[str, model_runtime.ModelRuntime]]
|
383
|
+
if model_dict.get("runtimes", None):
|
384
|
+
runtimes = {
|
385
|
+
name: model_runtime.ModelRuntime.load(pathlib.Path(model_dir_path), name, env, runtime_dict)
|
386
|
+
for name, runtime_dict in model_dict["runtimes"].items()
|
387
|
+
}
|
388
|
+
else:
|
389
|
+
runtimes = None
|
390
|
+
|
366
391
|
return cls(
|
367
392
|
name=model_dict["name"],
|
368
393
|
model_type=model_dict["model_type"],
|
369
394
|
env=env,
|
395
|
+
runtimes=runtimes,
|
370
396
|
signatures=signatures,
|
371
397
|
metadata=model_dict.get("metadata", None),
|
372
398
|
creation_timestamp=model_dict["creation_timestamp"],
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# This file contains the schema definition of what will be written into model.yaml
|
2
2
|
# Changing this file should lead to a change of the schema version.
|
3
3
|
|
4
|
-
from typing import Any, Dict, Optional, TypedDict, Union
|
4
|
+
from typing import Any, Dict, List, Optional, TypedDict, Union
|
5
5
|
|
6
6
|
from typing_extensions import NotRequired, Required
|
7
7
|
|
@@ -11,6 +11,16 @@ MODEL_METADATA_VERSION = "2023-12-01"
|
|
11
11
|
MODEL_METADATA_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
12
12
|
|
13
13
|
|
14
|
+
# Schema of the `dependencies` value of a ModelRuntimeDict entry.
# Both keys are required strings.
# NOTE(review): presumably each string locates the conda / pip dependency file of the
# runtime, mirroring the `conda` and `pip` keys of ModelEnvDict -- confirm.
class ModelRuntimeDependenciesDict(TypedDict):
|
15
|
+
conda: Required[str]
|
16
|
+
pip: Required[str]
|
17
|
+
|
18
|
+
|
19
|
+
# Schema of a single runtime entry; ModelMetadataDict.runtimes maps a runtime name to it.
# `imports` is a required list of strings and `dependencies` a required
# ModelRuntimeDependenciesDict.
# NOTE(review): what the `imports` strings refer to is not visible here -- confirm.
class ModelRuntimeDict(TypedDict):
|
20
|
+
imports: Required[List[str]]
|
21
|
+
dependencies: Required[ModelRuntimeDependenciesDict]
|
22
|
+
|
23
|
+
|
14
24
|
class ModelEnvDict(TypedDict):
|
15
25
|
conda: Required[str]
|
16
26
|
pip: Required[str]
|
@@ -23,11 +33,19 @@ class BaseModelBlobOptions(TypedDict):
|
|
23
33
|
...
|
24
34
|
|
25
35
|
|
36
|
+
# Blob options for CatBoost models, extending BaseModelBlobOptions with one required
# string key.
# NOTE(review): presumably the name of the catboost estimator class, analogous to
# `lightgbm_estimator_type` in LightGBMModelBlobOptions -- confirm against the handler.
class CatBoostModelBlobOptions(BaseModelBlobOptions):
|
37
|
+
catboost_estimator_type: Required[str]
|
38
|
+
|
39
|
+
|
26
40
|
# Blob options for Hugging Face pipeline models, extending BaseModelBlobOptions with a
# required string `task` and a required integer `batch_size`.
# NOTE(review): presumably `task` is the pipeline task name -- confirm against the handler.
class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
|
27
41
|
task: Required[str]
|
28
42
|
batch_size: Required[int]
|
29
43
|
|
30
44
|
|
45
|
+
# Blob options for LightGBM models, extending BaseModelBlobOptions with one required
# string key. The LightGBM loader raises ValueError when the key is missing or when the
# `lightgbm` module has no attribute of that name, and asserts that the unpickled model
# is an instance of `getattr(lightgbm, lightgbm_estimator_type)`.
class LightGBMModelBlobOptions(BaseModelBlobOptions):
|
46
|
+
lightgbm_estimator_type: Required[str]
|
47
|
+
|
48
|
+
|
31
49
|
class LLMModelBlobOptions(BaseModelBlobOptions):
|
32
50
|
batch_size: Required[int]
|
33
51
|
|
@@ -61,6 +79,7 @@ class ModelBlobMetadataDict(TypedDict):
|
|
61
79
|
class ModelMetadataDict(TypedDict):
|
62
80
|
creation_timestamp: Required[str]
|
63
81
|
env: Required[ModelEnvDict]
|
82
|
+
runtimes: NotRequired[Dict[str, ModelRuntimeDict]]
|
64
83
|
metadata: NotRequired[Optional[Dict[str, str]]]
|
65
84
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
66
85
|
models: Required[Dict[str, ModelBlobMetadataDict]]
|
@@ -3,7 +3,9 @@ from typing import Any, Dict, Type
|
|
3
3
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
4
4
|
from snowflake.ml.model._packager.model_meta_migrator import base_migrator, migrator_v1
|
5
5
|
|
6
|
-
MODEL_META_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelMetaMigrator]] = {
|
6
|
+
MODEL_META_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelMetaMigrator]] = {
|
7
|
+
"1": migrator_v1.MetaMigrator_v1,
|
8
|
+
}
|
7
9
|
|
8
10
|
|
9
11
|
def migrate_metadata(loaded_meta: Dict[str, Any]) -> Dict[str, Any]:
|
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional
|
|
4
4
|
|
5
5
|
from absl import logging
|
6
6
|
|
7
|
-
from snowflake.ml._internal import env_utils
|
8
7
|
from snowflake.ml._internal.exceptions import (
|
9
8
|
error_codes,
|
10
9
|
exceptions as snowml_exceptions,
|
@@ -102,8 +101,8 @@ class ModelPackager:
|
|
102
101
|
if signatures is None:
|
103
102
|
logging.info(f"Model signatures are auto inferred as:\n\n{meta.signatures}")
|
104
103
|
|
105
|
-
|
106
|
-
|
104
|
+
self.model = model
|
105
|
+
self.meta = meta
|
107
106
|
|
108
107
|
def load(
|
109
108
|
self,
|
@@ -129,8 +128,6 @@ class ModelPackager:
|
|
129
128
|
|
130
129
|
model_meta.load_code_path(self.local_dir_path)
|
131
130
|
|
132
|
-
env_utils.validate_py_runtime_version(self.meta.env.python_version)
|
133
|
-
|
134
131
|
handler = model_handler.load_handler(self.meta.model_type)
|
135
132
|
if handler is None:
|
136
133
|
raise snowml_exceptions.SnowflakeMLException(
|