snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +11 -1
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/feature_store/feature_store.py +166 -184
- snowflake/ml/feature_store/feature_view.py +12 -24
- snowflake/ml/fileset/sfcfs.py +56 -50
- snowflake/ml/fileset/stage_fs.py +48 -13
- snowflake/ml/model/_client/model/model_version_impl.py +6 -49
- snowflake/ml/model/_client/ops/model_ops.py +78 -29
- snowflake/ml/model/_client/sql/model.py +23 -2
- snowflake/ml/model/_client/sql/model_version.py +22 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +7 -5
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -2
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
- snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
- snowflake/ml/modeling/cluster/birch.py +195 -123
- snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
- snowflake/ml/modeling/cluster/dbscan.py +195 -123
- snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
- snowflake/ml/modeling/cluster/k_means.py +195 -123
- snowflake/ml/modeling/cluster/mean_shift.py +195 -123
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
- snowflake/ml/modeling/cluster/optics.py +195 -123
- snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
- snowflake/ml/modeling/compose/column_transformer.py +195 -123
- snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
- snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
- snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
- snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
- snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
- snowflake/ml/modeling/covariance/oas.py +195 -123
- snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
- snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
- snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
- snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
- snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/pca.py +195 -123
- snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
- snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
- snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
- snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +24 -6
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
- snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
- snowflake/ml/modeling/impute/knn_imputer.py +195 -123
- snowflake/ml/modeling/impute/missing_indicator.py +195 -123
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
- snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/lars.py +195 -123
- snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
- snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/perceptron.py +195 -123
- snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ridge.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
- snowflake/ml/modeling/manifold/isomap.py +195 -123
- snowflake/ml/modeling/manifold/mds.py +195 -123
- snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
- snowflake/ml/modeling/manifold/tsne.py +195 -123
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
- snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
- snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
- snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
- snowflake/ml/modeling/svm/linear_svc.py +195 -123
- snowflake/ml/modeling/svm/linear_svr.py +195 -123
- snowflake/ml/modeling/svm/nu_svc.py +195 -123
- snowflake/ml/modeling/svm/nu_svr.py +195 -123
- snowflake/ml/modeling/svm/svc.py +195 -123
- snowflake/ml/modeling/svm/svr.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
- snowflake/ml/registry/_manager/model_manager.py +5 -1
- snowflake/ml/registry/model_registry.py +99 -26
- snowflake/ml/registry/registry.py +3 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
|
|
1
1
|
import glob
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
+
import uuid
|
4
5
|
import zipfile
|
5
6
|
from types import ModuleType
|
6
7
|
from typing import Any, Dict, List, Optional
|
@@ -31,7 +32,6 @@ class ModelComposer:
|
|
31
32
|
will zip it. This would not be required if we make directory import work.
|
32
33
|
"""
|
33
34
|
|
34
|
-
MODEL_FILE_REL_PATH = "model.zip"
|
35
35
|
MODEL_DIR_REL_PATH = "model"
|
36
36
|
|
37
37
|
def __init__(
|
@@ -50,6 +50,8 @@ class ModelComposer:
|
|
50
50
|
self.packager = model_packager.ModelPackager(local_dir_path=str(self._packager_workspace_path))
|
51
51
|
self.manifest = model_manifest.ModelManifest(workspace_path=self.workspace_path)
|
52
52
|
|
53
|
+
self.model_file_rel_path = f"model-{uuid.uuid4().hex}.zip"
|
54
|
+
|
53
55
|
self._statement_params = statement_params
|
54
56
|
|
55
57
|
def __del__(self) -> None:
|
@@ -66,11 +68,11 @@ class ModelComposer:
|
|
66
68
|
|
67
69
|
@property
|
68
70
|
def model_stage_path(self) -> str:
|
69
|
-
return (self.stage_path /
|
71
|
+
return (self.stage_path / self.model_file_rel_path).as_posix()
|
70
72
|
|
71
73
|
@property
|
72
74
|
def model_local_path(self) -> str:
|
73
|
-
return str(self.workspace_path /
|
75
|
+
return str(self.workspace_path / self.model_file_rel_path)
|
74
76
|
|
75
77
|
def save(
|
76
78
|
self,
|
@@ -130,7 +132,7 @@ class ModelComposer:
|
|
130
132
|
self.manifest.save(
|
131
133
|
session=self.session,
|
132
134
|
model_meta=self.packager.meta,
|
133
|
-
model_file_rel_path=pathlib.PurePosixPath(
|
135
|
+
model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
|
134
136
|
options=options,
|
135
137
|
)
|
136
138
|
|
@@ -156,7 +158,7 @@ class ModelComposer:
|
|
156
158
|
|
157
159
|
# TODO (Server-side Model Rollout): Remove this section.
|
158
160
|
model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0])
|
159
|
-
|
161
|
+
self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path))
|
160
162
|
|
161
163
|
with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
|
162
164
|
zf.extractall(path=self._packager_workspace_path)
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import collections
|
2
|
+
import copy
|
2
3
|
import pathlib
|
3
|
-
from typing import
|
4
|
+
from typing import List, Optional, cast
|
4
5
|
|
5
6
|
import yaml
|
6
7
|
|
@@ -10,7 +11,6 @@ from snowflake.ml.model._model_composer.model_method import (
|
|
10
11
|
function_generator,
|
11
12
|
model_method,
|
12
13
|
)
|
13
|
-
from snowflake.ml.model._model_composer.model_runtime import model_runtime
|
14
14
|
from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
15
15
|
from snowflake.snowpark import Session
|
16
16
|
|
@@ -39,21 +39,19 @@ class ModelManifest:
|
|
39
39
|
) -> None:
|
40
40
|
if options is None:
|
41
41
|
options = {}
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
)
|
49
|
-
]
|
42
|
+
|
43
|
+
runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
|
44
|
+
runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
|
45
|
+
runtime_to_use.imports.append(model_file_rel_path)
|
46
|
+
runtime_dict = runtime_to_use.save(self.workspace_path)
|
47
|
+
|
50
48
|
self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path)
|
51
49
|
self.methods: List[model_method.ModelMethod] = []
|
52
50
|
for target_method in model_meta.signatures.keys():
|
53
51
|
method = model_method.ModelMethod(
|
54
52
|
model_meta=model_meta,
|
55
53
|
target_method=target_method,
|
56
|
-
runtime_name=self.
|
54
|
+
runtime_name=self._DEFAULT_RUNTIME_NAME,
|
57
55
|
function_generator=self.function_generator,
|
58
56
|
options=model_method.get_model_method_options_from_options(options, target_method),
|
59
57
|
)
|
@@ -71,7 +69,16 @@ class ModelManifest:
|
|
71
69
|
|
72
70
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
73
71
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
74
|
-
runtimes={
|
72
|
+
runtimes={
|
73
|
+
self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
|
74
|
+
language="PYTHON",
|
75
|
+
version=runtime_to_use.runtime_env.python_version,
|
76
|
+
imports=runtime_dict["imports"],
|
77
|
+
dependencies=model_manifest_schema.ModelRuntimeDependenciesDict(
|
78
|
+
conda=runtime_dict["dependencies"]["conda"]
|
79
|
+
),
|
80
|
+
)
|
81
|
+
},
|
75
82
|
methods=[
|
76
83
|
method.save(
|
77
84
|
self.workspace_path,
|
@@ -83,8 +90,6 @@ class ModelManifest:
|
|
83
90
|
],
|
84
91
|
)
|
85
92
|
|
86
|
-
manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
|
87
|
-
|
88
93
|
with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
|
89
94
|
# Anchors are not supported in the server, avoid that.
|
90
95
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
@@ -103,43 +108,3 @@ class ModelManifest:
|
|
103
108
|
res = cast(model_manifest_schema.ModelManifestDict, raw_input)
|
104
109
|
|
105
110
|
return res
|
106
|
-
|
107
|
-
def generate_user_data_with_client_data(self, model_meta: model_meta_api.ModelMetadata) -> Dict[str, Any]:
|
108
|
-
client_data = model_manifest_schema.SnowparkMLDataDict(
|
109
|
-
schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION,
|
110
|
-
functions=[
|
111
|
-
model_manifest_schema.ModelFunctionInfoDict(
|
112
|
-
name=method.method_name.identifier(),
|
113
|
-
target_method=method.target_method,
|
114
|
-
signature=model_meta.signatures[method.target_method].to_dict(),
|
115
|
-
)
|
116
|
-
for method in self.methods
|
117
|
-
],
|
118
|
-
)
|
119
|
-
return {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: client_data}
|
120
|
-
|
121
|
-
@staticmethod
|
122
|
-
def parse_client_data_from_user_data(raw_user_data: Dict[str, Any]) -> model_manifest_schema.SnowparkMLDataDict:
|
123
|
-
raw_client_data = raw_user_data.get(model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME, {})
|
124
|
-
if not isinstance(raw_client_data, dict) or "schema_version" not in raw_client_data:
|
125
|
-
raise ValueError(f"Ill-formatted client data {raw_client_data} in user data found.")
|
126
|
-
loaded_client_data_schema_version = raw_client_data["schema_version"]
|
127
|
-
if (
|
128
|
-
not isinstance(loaded_client_data_schema_version, str)
|
129
|
-
or loaded_client_data_schema_version != model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION
|
130
|
-
):
|
131
|
-
raise ValueError(f"Unsupported client data schema version {loaded_client_data_schema_version} confronted.")
|
132
|
-
|
133
|
-
return_functions_info: List[model_manifest_schema.ModelFunctionInfoDict] = []
|
134
|
-
loaded_functions_info = raw_client_data.get("functions", [])
|
135
|
-
for func in loaded_functions_info:
|
136
|
-
fi = model_manifest_schema.ModelFunctionInfoDict(
|
137
|
-
name=func["name"],
|
138
|
-
target_method=func["target_method"],
|
139
|
-
signature=func["signature"],
|
140
|
-
)
|
141
|
-
return_functions_info.append(fi)
|
142
|
-
|
143
|
-
return model_manifest_schema.SnowparkMLDataDict(
|
144
|
-
schema_version=loaded_client_data_schema_version, functions=return_functions_info
|
145
|
-
)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
|
-
|
2
|
+
import enum
|
3
3
|
from typing import Any, Dict, List, Literal, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
@@ -12,6 +12,11 @@ MANIFEST_CLIENT_DATA_KEY_NAME = "snowpark_ml_data"
|
|
12
12
|
MANIFEST_CLIENT_DATA_SCHEMA_VERSION = "2024-02-01"
|
13
13
|
|
14
14
|
|
15
|
+
class ModelMethodFunctionTypes(enum.Enum):
|
16
|
+
FUNCTION = "FUNCTION"
|
17
|
+
TABLE_FUNCTION = "TABLE_FUNCTION"
|
18
|
+
|
19
|
+
|
15
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
16
21
|
conda: Required[str]
|
17
22
|
|
@@ -49,11 +54,13 @@ class ModelFunctionInfo(TypedDict):
|
|
49
54
|
Attributes:
|
50
55
|
name: Name of the function to be called via SQL.
|
51
56
|
target_method: actual target method name to be called.
|
57
|
+
target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
|
52
58
|
signature: The signature of the model method.
|
53
59
|
"""
|
54
60
|
|
55
61
|
name: Required[str]
|
56
62
|
target_method: Required[str]
|
63
|
+
target_method_function_type: Required[str]
|
57
64
|
signature: Required[model_signature.ModelSignature]
|
58
65
|
|
59
66
|
|
@@ -73,5 +73,5 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
73
73
|
# Actual table function
|
74
74
|
class {function_name}:
|
75
75
|
@vectorized(input=pd.DataFrame)
|
76
|
-
def end_partition(df: pd.DataFrame) -> pd.DataFrame:
|
76
|
+
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
77
77
|
return runner(df)
|
@@ -1,5 +1,4 @@
|
|
1
1
|
import collections
|
2
|
-
import enum
|
3
2
|
import pathlib
|
4
3
|
from typing import List, Optional, TypedDict, Union
|
5
4
|
|
@@ -13,11 +12,6 @@ from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
|
13
12
|
from snowflake.snowpark._internal import type_utils
|
14
13
|
|
15
14
|
|
16
|
-
class ModelMethodFunctionTypes(enum.Enum):
|
17
|
-
FUNCTION = "FUNCTION"
|
18
|
-
TABLE_FUNCTION = "TABLE_FUNCTION"
|
19
|
-
|
20
|
-
|
21
15
|
class ModelMethodOptions(TypedDict):
|
22
16
|
"""Options when creating model method.
|
23
17
|
|
@@ -33,9 +27,9 @@ def get_model_method_options_from_options(
|
|
33
27
|
options: type_hints.ModelSaveOption, target_method: str
|
34
28
|
) -> ModelMethodOptions:
|
35
29
|
method_option = options.get("method_options", {}).get(target_method, {})
|
36
|
-
global_function_type = options.get("function_type", ModelMethodFunctionTypes.FUNCTION.value)
|
30
|
+
global_function_type = options.get("function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value)
|
37
31
|
function_type = method_option.get("function_type", global_function_type)
|
38
|
-
if function_type not in [function_type.value for function_type in ModelMethodFunctionTypes]:
|
32
|
+
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
39
33
|
raise NotImplementedError
|
40
34
|
|
41
35
|
# TODO(TH): enforce minimum snowflake version
|
@@ -89,7 +83,9 @@ class ModelMethod:
|
|
89
83
|
if self.target_method not in self.model_meta.signatures.keys():
|
90
84
|
raise ValueError(f"Target method {self.target_method} is not available in the signatures of the model.")
|
91
85
|
|
92
|
-
self.function_type = self.options.get(
|
86
|
+
self.function_type = self.options.get(
|
87
|
+
"function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
88
|
+
)
|
93
89
|
|
94
90
|
@staticmethod
|
95
91
|
def _get_method_arg_from_feature(
|
@@ -134,7 +130,7 @@ class ModelMethod:
|
|
134
130
|
List[model_manifest_schema.ModelMethodSignatureField],
|
135
131
|
List[model_manifest_schema.ModelMethodSignatureFieldWithName],
|
136
132
|
]
|
137
|
-
if self.function_type == ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
133
|
+
if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
138
134
|
outputs = [
|
139
135
|
ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
|
140
136
|
for ft in self.model_meta.signatures[self.target_method].outputs
|
@@ -0,0 +1,206 @@
|
|
1
|
+
import os
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
from typing_extensions import TypeGuard, Unpack
|
7
|
+
|
8
|
+
from snowflake.ml._internal import type_utils
|
9
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
10
|
+
from snowflake.ml.model._packager.model_env import model_env
|
11
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
12
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
13
|
+
from snowflake.ml.model._packager.model_meta import (
|
14
|
+
model_blob_meta,
|
15
|
+
model_meta as model_meta_api,
|
16
|
+
model_meta_schema,
|
17
|
+
)
|
18
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
19
|
+
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
import catboost
|
22
|
+
|
23
|
+
|
24
|
+
@final
|
25
|
+
class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
26
|
+
"""Handler for CatBoost based model."""
|
27
|
+
|
28
|
+
HANDLER_TYPE = "catboost"
|
29
|
+
HANDLER_VERSION = "2024-03-21"
|
30
|
+
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
31
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
32
|
+
|
33
|
+
MODELE_BLOB_FILE_OR_DIR = "model.bin"
|
34
|
+
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
35
|
+
|
36
|
+
@classmethod
|
37
|
+
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
|
38
|
+
return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
|
39
|
+
(hasattr(model, method) and callable(getattr(model, method, None))) for method in cls.DEFAULT_TARGET_METHODS
|
40
|
+
)
|
41
|
+
|
42
|
+
@classmethod
|
43
|
+
def cast_model(
|
44
|
+
cls,
|
45
|
+
model: model_types.SupportedModelType,
|
46
|
+
) -> "catboost.CatBoost":
|
47
|
+
import catboost
|
48
|
+
|
49
|
+
assert isinstance(model, catboost.CatBoost)
|
50
|
+
|
51
|
+
return model
|
52
|
+
|
53
|
+
@classmethod
|
54
|
+
def save_model(
|
55
|
+
cls,
|
56
|
+
name: str,
|
57
|
+
model: "catboost.CatBoost",
|
58
|
+
model_meta: model_meta_api.ModelMetadata,
|
59
|
+
model_blobs_dir_path: str,
|
60
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
61
|
+
is_sub_model: Optional[bool] = False,
|
62
|
+
**kwargs: Unpack[model_types.CatBoostModelSaveOptions],
|
63
|
+
) -> None:
|
64
|
+
import catboost
|
65
|
+
|
66
|
+
assert isinstance(model, catboost.CatBoost)
|
67
|
+
|
68
|
+
if not is_sub_model:
|
69
|
+
target_methods = handlers_utils.get_target_methods(
|
70
|
+
model=model,
|
71
|
+
target_methods=kwargs.pop("target_methods", None),
|
72
|
+
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
73
|
+
)
|
74
|
+
|
75
|
+
def get_prediction(
|
76
|
+
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
77
|
+
) -> model_types.SupportedLocalDataType:
|
78
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
79
|
+
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
80
|
+
target_method = getattr(model, target_method_name, None)
|
81
|
+
assert callable(target_method)
|
82
|
+
predictions_df = target_method(sample_input_data)
|
83
|
+
return predictions_df
|
84
|
+
|
85
|
+
model_meta = handlers_utils.validate_signature(
|
86
|
+
model=model,
|
87
|
+
model_meta=model_meta,
|
88
|
+
target_methods=target_methods,
|
89
|
+
sample_input_data=sample_input_data,
|
90
|
+
get_prediction_fn=get_prediction,
|
91
|
+
)
|
92
|
+
|
93
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
94
|
+
os.makedirs(model_blob_path, exist_ok=True)
|
95
|
+
model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
|
96
|
+
|
97
|
+
model.save_model(model_save_path)
|
98
|
+
|
99
|
+
base_meta = model_blob_meta.ModelBlobMeta(
|
100
|
+
name=name,
|
101
|
+
model_type=cls.HANDLER_TYPE,
|
102
|
+
handler_version=cls.HANDLER_VERSION,
|
103
|
+
path=cls.MODELE_BLOB_FILE_OR_DIR,
|
104
|
+
options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
|
105
|
+
)
|
106
|
+
model_meta.models[name] = base_meta
|
107
|
+
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
108
|
+
|
109
|
+
model_meta.env.include_if_absent(
|
110
|
+
[
|
111
|
+
model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
|
112
|
+
],
|
113
|
+
check_local_version=True,
|
114
|
+
)
|
115
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
116
|
+
|
117
|
+
return None
|
118
|
+
|
119
|
+
@classmethod
def load_model(
    cls,
    name: str,
    model_meta: model_meta_api.ModelMetadata,
    model_blobs_dir_path: str,
    **kwargs: Unpack[model_types.ModelLoadOption],
) -> "catboost.CatBoost":
    """Rebuild a CatBoost estimator from its saved blob.

    Args:
        name: Key of the model in ``model_meta.models``; also the name of the
            sub-directory of ``model_blobs_dir_path`` that holds the blob.
        model_meta: Metadata object whose ``models[name]`` entry supplies the
            blob file name (``path``) and the handler ``options``.
        model_blobs_dir_path: Directory containing one sub-directory per model.
        **kwargs: Load options. Only ``use_gpu`` is read here.

    Returns:
        An instance of the estimator class named in the blob options, with the
        saved model file loaded into it.

    Raises:
        ValueError: If ``catboost_estimator_type`` is missing from the blob
            options, or ``catboost`` exposes no attribute with that name.
    """
    import catboost

    blob_dir = os.path.join(model_blobs_dir_path, name)
    blob_meta = model_meta.models[name]
    blob_file = os.path.join(blob_dir, blob_meta.path)

    blob_options = cast(model_meta_schema.CatBoostModelBlobOptions, blob_meta.options)
    if "catboost_estimator_type" not in blob_options:
        raise ValueError("Missing field `catboost_estimator_type` in model blob metadata for type `catboost`")

    estimator_name = blob_options["catboost_estimator_type"]
    if not hasattr(catboost, estimator_name):
        raise ValueError("Type of CatBoost estimator is not supported.")
    estimator_cls = getattr(catboost, estimator_name)

    # The blob is expected to be a single file, not a directory.
    assert os.path.isfile(blob_file)
    model = estimator_cls()
    model.load_model(blob_file)
    assert isinstance(model, estimator_cls)

    use_gpu = kwargs.get("use_gpu", False)
    if use_gpu:
        assert isinstance(use_gpu, bool)
        # Written straight into the instance ``__dict__``.
        # NOTE(review): confirm CatBoost actually honours a ``task_type``
        # attribute that is set this way after loading.
        model.__dict__.update({"task_type": "GPU"})

    return model
|
154
|
+
|
155
|
+
@classmethod
def convert_as_custom_model(
    cls,
    raw_model: "catboost.CatBoost",
    model_meta: model_meta_api.ModelMetadata,
    **kwargs: Unpack[model_types.ModelLoadOption],
) -> custom_model.CustomModel:
    """Wrap a CatBoost model in a dynamically created ``CustomModel``.

    One inference method is generated per entry in ``model_meta.signatures``.
    Each generated method calls the same-named method on ``raw_model``, turns
    the result into a DataFrame and renames its columns from the signature's
    outputs.

    Args:
        raw_model: The CatBoost model to wrap.
        model_meta: Metadata providing the ``signatures`` mapping.
        **kwargs: Load options; not read by this method.

    Returns:
        An instance of the generated ``_CatBoostModel`` class, built with an
        empty ``ModelContext``.
    """
    import catboost

    from snowflake.ml.model import custom_model

    def _create_custom_model(
        raw_model: "catboost.CatBoost",
        model_meta: model_meta_api.ModelMetadata,
    ) -> Type[custom_model.CustomModel]:
        def fn_factory(
            raw_model: "catboost.CatBoost",
            signature: model_signature.ModelSignature,
            target_method: str,
        ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
            @custom_model.inference_api
            def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                raw_output = getattr(raw_model, target_method)(X)

                # A non-empty list whose first item is an ndarray is treated
                # as multi-output and converted by the sequence handler;
                # anything else goes straight to the DataFrame constructor.
                is_ndarray_list = (
                    isinstance(raw_output, list)
                    and len(raw_output) > 0
                    and isinstance(raw_output[0], np.ndarray)
                )
                if not is_ndarray_list:
                    frame = pd.DataFrame(raw_output)
                else:
                    frame = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(raw_output)

                return model_signature_utils.rename_pandas_df(frame, signature.outputs)

            return fn

        # Namespace of the generated class: the raw model plus one inference
        # method per signature.
        class_namespace: Dict[str, Any] = {"_raw_model": raw_model}
        class_namespace.update(
            {
                method_name: fn_factory(raw_model, method_sig, method_name)
                for method_name, method_sig in model_meta.signatures.items()
            }
        )

        return type("_CatBoostModel", (custom_model.CustomModel,), class_namespace)

    model_cls = _create_custom_model(raw_model, model_meta)
    return model_cls(custom_model.ModelContext())
|
@@ -0,0 +1,218 @@
|
|
1
|
+
import os
|
2
|
+
from typing import (
|
3
|
+
TYPE_CHECKING,
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
Dict,
|
7
|
+
Optional,
|
8
|
+
Type,
|
9
|
+
Union,
|
10
|
+
cast,
|
11
|
+
final,
|
12
|
+
)
|
13
|
+
|
14
|
+
import cloudpickle
|
15
|
+
import numpy as np
|
16
|
+
import pandas as pd
|
17
|
+
from typing_extensions import TypeGuard, Unpack
|
18
|
+
|
19
|
+
from snowflake.ml._internal import type_utils
|
20
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
21
|
+
from snowflake.ml.model._packager.model_env import model_env
|
22
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
23
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
24
|
+
from snowflake.ml.model._packager.model_meta import (
|
25
|
+
model_blob_meta,
|
26
|
+
model_meta as model_meta_api,
|
27
|
+
model_meta_schema,
|
28
|
+
)
|
29
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
30
|
+
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
import lightgbm
|
33
|
+
|
34
|
+
|
35
|
+
@final
class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]):
    """Handler for LightGBM based model.

    Accepts ``lightgbm.Booster`` and ``lightgbm.LGBMModel`` instances. Models
    are written to and read from disk with ``cloudpickle``.
    """

    # Identifier stored in the blob metadata as ``model_type``.
    HANDLER_TYPE = "lightgbm"
    # Stored in the blob metadata as ``handler_version``.
    HANDLER_VERSION = "2024-03-19"
    # Written to ``model_meta.min_snowpark_ml_version`` on save.
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    # No migration plans are registered for this handler.
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # File name of the pickled model inside the model's blob directory.
    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
    # Methods looked for on the model when the caller names none.
    DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

    @classmethod
    def can_handle(
        cls, model: model_types.SupportedModelType
    ) -> TypeGuard[Union["lightgbm.Booster", "lightgbm.LGBMModel"]]:
        """Return whether ``model`` is a LightGBM model this handler supports.

        The model must be a ``lightgbm.Booster`` or ``lightgbm.LGBMModel`` and
        expose at least one callable from ``DEFAULT_TARGET_METHODS``.
        """
        is_lightgbm_model = type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
            "lightgbm.LGBMModel"
        ).isinstance(model)
        # ``getattr`` with a ``None`` default already covers the missing-attribute case.
        return is_lightgbm_model and any(
            callable(getattr(model, method, None)) for method in cls.DEFAULT_TARGET_METHODS
        )

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Return ``model`` unchanged after asserting it is a LightGBM model."""
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        return model

    @classmethod
    def save_model(
        cls,
        name: str,
        model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.LGBMModelSaveOptions],
    ) -> None:
        """Pickle ``model`` to disk and record it in ``model_meta``.

        Args:
            name: Name of the model; used as the blob sub-directory name and
                as the key in ``model_meta.models``.
            model: The LightGBM model to save.
            model_meta: Metadata object updated in place with the blob entry,
                the minimum library version and the package dependencies.
            model_blobs_dir_path: Directory under which ``<name>/model.pkl``
                is written.
            sample_input_data: Optional sample input handed to signature
                validation.
            is_sub_model: When truthy, target-method resolution and signature
                validation are skipped.
            **kwargs: Save options; ``target_methods`` is popped from here.
        """
        import lightgbm

        assert isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))

        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # DataFrames and ndarrays are passed to the model as they are;
                # any other local data is converted to a DataFrame first.
                if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
                    sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
                target_method = getattr(model, target_method_name, None)
                assert callable(target_method)
                predictions_df = target_method(sample_input_data)
                return predictions_df

            model_meta = handlers_utils.validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=target_methods,
                sample_input_data=sample_input_data,
                get_prediction_fn=get_prediction,
            )

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)

        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
        with open(model_save_path, "wb") as f:
            cloudpickle.dump(model, f)

        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
            # The class name is stored so that ``load_model`` can check the unpickled object's type.
            options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
            ],
            check_local_version=True,
        )

        return None

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
        """Unpickle a LightGBM model previously written by ``save_model``.

        Args:
            name: Key of the model in ``model_meta.models``; also the blob
                sub-directory name.
            model_meta: Metadata object supplying the blob path and options.
            model_blobs_dir_path: Directory containing one sub-directory per
                model.
            **kwargs: Load options; not read by this method.

        Returns:
            The unpickled model, asserted to be an instance of the estimator
            class named in the blob options.

        Raises:
            ValueError: If ``lightgbm_estimator_type`` is missing from the
                blob options, or ``lightgbm`` exposes no attribute with that
                name.
        """
        import lightgbm

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_path = os.path.join(model_blob_path, model_blob_filename)

        model_blob_options = cast(model_meta_schema.LightGBMModelBlobOptions, model_blob_metadata.options)
        if "lightgbm_estimator_type" not in model_blob_options:
            raise ValueError("Missing field `lightgbm_estimator_type` in model blob metadata for type `lightgbm`")

        lightgbm_estimator_type = model_blob_options["lightgbm_estimator_type"]
        if not hasattr(lightgbm, lightgbm_estimator_type):
            raise ValueError("Type of LightGBM estimator is not supported.")

        assert os.path.isfile(model_blob_file_path)  # saved model is a file
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file. Only load model blobs that come from a trusted source.
        with open(model_blob_file_path, "rb") as f:
            model = cloudpickle.load(f)
        assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))

        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap a LightGBM model in a dynamically created ``CustomModel``.

        One inference method is generated per entry in
        ``model_meta.signatures``. Each generated method calls the same-named
        method on ``raw_model``, converts the result to a DataFrame and
        renames its columns from the signature's outputs.

        Args:
            raw_model: The LightGBM model to wrap.
            model_meta: Metadata providing the ``signatures`` mapping.
            **kwargs: Load options; not read by this method.

        Returns:
            An instance of the generated ``_LightGBMModel`` class, built with
            an empty ``ModelContext``.
        """
        import lightgbm

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def fn_factory(
                raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:

                    res = getattr(raw_model, target_method)(X)

                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
                        # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
                        # return a list of ndarrays. We need to deal with them separately.
                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
                    else:
                        df = pd.DataFrame(res)

                    return model_signature_utils.rename_pandas_df(df, signature.outputs)

                return fn

            # Namespace of the generated class: the raw model plus one
            # inference method per signature.
            type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
            for target_method_name, sig in model_meta.signatures.items():
                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)

            _LightGBMModel = type(
                "_LightGBMModel",
                (custom_model.CustomModel,),
                type_method_dict,
            )

            return _LightGBMModel

        _LightGBMModel = _create_custom_model(raw_model, model_meta)
        lightgbm_model = _LightGBMModel(custom_model.ModelContext())

        return lightgbm_model
|