snowflake-ml-python 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +35 -40
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/identifier.py +74 -7
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_core_requirements.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +25 -28
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +7 -4
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/_base.py +3 -1
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -8
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +27 -21
- snowflake/ml/model/_model_meta.py +33 -19
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +28 -15
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +79 -43
- snowflake/ml/modeling/cluster/affinity_propagation.py +79 -43
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +79 -43
- snowflake/ml/modeling/cluster/birch.py +79 -43
- snowflake/ml/modeling/cluster/bisecting_k_means.py +79 -43
- snowflake/ml/modeling/cluster/dbscan.py +79 -43
- snowflake/ml/modeling/cluster/feature_agglomeration.py +79 -43
- snowflake/ml/modeling/cluster/k_means.py +79 -43
- snowflake/ml/modeling/cluster/mean_shift.py +79 -43
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +79 -43
- snowflake/ml/modeling/cluster/optics.py +79 -43
- snowflake/ml/modeling/cluster/spectral_biclustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_clustering.py +79 -43
- snowflake/ml/modeling/cluster/spectral_coclustering.py +79 -43
- snowflake/ml/modeling/compose/column_transformer.py +79 -43
- snowflake/ml/modeling/compose/transformed_target_regressor.py +79 -43
- snowflake/ml/modeling/covariance/elliptic_envelope.py +79 -43
- snowflake/ml/modeling/covariance/empirical_covariance.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso.py +79 -43
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +79 -43
- snowflake/ml/modeling/covariance/ledoit_wolf.py +79 -43
- snowflake/ml/modeling/covariance/min_cov_det.py +79 -43
- snowflake/ml/modeling/covariance/oas.py +79 -43
- snowflake/ml/modeling/covariance/shrunk_covariance.py +79 -43
- snowflake/ml/modeling/decomposition/dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/factor_analysis.py +79 -43
- snowflake/ml/modeling/decomposition/fast_ica.py +79 -43
- snowflake/ml/modeling/decomposition/incremental_pca.py +79 -43
- snowflake/ml/modeling/decomposition/kernel_pca.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +79 -43
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/pca.py +79 -43
- snowflake/ml/modeling/decomposition/sparse_pca.py +79 -43
- snowflake/ml/modeling/decomposition/truncated_svd.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/bagging_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/isolation_forest.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/stacking_regressor.py +79 -43
- snowflake/ml/modeling/ensemble/voting_classifier.py +79 -43
- snowflake/ml/modeling/ensemble/voting_regressor.py +79 -43
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fdr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fpr.py +79 -43
- snowflake/ml/modeling/feature_selection/select_fwe.py +79 -43
- snowflake/ml/modeling/feature_selection/select_k_best.py +79 -43
- snowflake/ml/modeling/feature_selection/select_percentile.py +79 -43
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +79 -43
- snowflake/ml/modeling/feature_selection/variance_threshold.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +79 -43
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +79 -43
- snowflake/ml/modeling/impute/iterative_imputer.py +79 -43
- snowflake/ml/modeling/impute/knn_imputer.py +79 -43
- snowflake/ml/modeling/impute/missing_indicator.py +79 -43
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/nystroem.py +79 -43
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +79 -43
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +79 -43
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +79 -43
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +79 -43
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ard_regression.py +79 -43
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/gamma_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/huber_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/lars.py +79 -43
- snowflake/ml/modeling/linear_model/lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +79 -43
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +79 -43
- snowflake/ml/modeling/linear_model/linear_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression.py +79 -43
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +79 -43
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +79 -43
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/perceptron.py +79 -43
- snowflake/ml/modeling/linear_model/poisson_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ransac_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/ridge.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +79 -43
- snowflake/ml/modeling/linear_model/ridge_cv.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_classifier.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +79 -43
- snowflake/ml/modeling/linear_model/sgd_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +79 -43
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +79 -43
- snowflake/ml/modeling/manifold/isomap.py +79 -43
- snowflake/ml/modeling/manifold/mds.py +79 -43
- snowflake/ml/modeling/manifold/spectral_embedding.py +79 -43
- snowflake/ml/modeling/manifold/tsne.py +79 -43
- snowflake/ml/modeling/metrics/classification.py +6 -1
- snowflake/ml/modeling/metrics/regression.py +517 -9
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +79 -43
- snowflake/ml/modeling/mixture/gaussian_mixture.py +79 -43
- snowflake/ml/modeling/model_selection/grid_search_cv.py +79 -43
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +79 -43
- snowflake/ml/modeling/multiclass/output_code_classifier.py +79 -43
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/complement_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -43
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neighbors/kernel_density.py +79 -43
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_centroid.py +79 -43
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +79 -43
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +79 -43
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +79 -43
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_classifier.py +79 -43
- snowflake/ml/modeling/neural_network/mlp_regressor.py +79 -43
- snowflake/ml/modeling/pipeline/pipeline.py +24 -0
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +18 -19
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_propagation.py +79 -43
- snowflake/ml/modeling/semi_supervised/label_spreading.py +79 -43
- snowflake/ml/modeling/svm/linear_svc.py +79 -43
- snowflake/ml/modeling/svm/linear_svr.py +79 -43
- snowflake/ml/modeling/svm/nu_svc.py +79 -43
- snowflake/ml/modeling/svm/nu_svr.py +79 -43
- snowflake/ml/modeling/svm/svc.py +79 -43
- snowflake/ml/modeling/svm/svr.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/decision_tree_regressor.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_classifier.py +79 -43
- snowflake/ml/modeling/tree/extra_tree_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgb_regressor.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +79 -43
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +79 -43
- snowflake/ml/registry/model_registry.py +123 -121
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +50 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.1.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.1.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,186 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
4
|
+
|
5
|
+
import cloudpickle
|
6
|
+
import pandas as pd
|
7
|
+
from typing_extensions import TypeGuard, Unpack
|
8
|
+
|
9
|
+
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml.model import (
|
11
|
+
_model_meta as model_meta_api,
|
12
|
+
custom_model,
|
13
|
+
model_signature,
|
14
|
+
type_hints as model_types,
|
15
|
+
)
|
16
|
+
from snowflake.ml.model._handlers import _base
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
class _PyTorchHandler(_base._ModelHandler["torch.nn.Module"]):
|
23
|
+
"""Handler for PyTorch based model.
|
24
|
+
|
25
|
+
Currently torch.nn.Module based classes are supported.
|
26
|
+
"""
|
27
|
+
|
28
|
+
handler_type = "pytorch"
|
29
|
+
MODEL_BLOB_FILE = "model.pt"
|
30
|
+
DEFAULT_TARGET_METHODS = ["forward"]
|
31
|
+
|
32
|
+
@staticmethod
|
33
|
+
def can_handle(
|
34
|
+
model: model_types.SupportedModelType,
|
35
|
+
) -> TypeGuard["torch.nn.Module"]:
|
36
|
+
return type_utils.LazyType("torch.nn.Module").isinstance(model) and not type_utils.LazyType(
|
37
|
+
"torch.jit.ScriptModule"
|
38
|
+
).isinstance(model)
|
39
|
+
|
40
|
+
@staticmethod
|
41
|
+
def cast_model(
|
42
|
+
model: model_types.SupportedModelType,
|
43
|
+
) -> "torch.nn.Module":
|
44
|
+
import torch
|
45
|
+
|
46
|
+
assert isinstance(model, torch.nn.Module)
|
47
|
+
|
48
|
+
return cast(torch.nn.Module, model)
|
49
|
+
|
50
|
+
@staticmethod
|
51
|
+
def _save_model(
|
52
|
+
name: str,
|
53
|
+
model: "torch.nn.Module",
|
54
|
+
model_meta: model_meta_api.ModelMetadata,
|
55
|
+
model_blobs_dir_path: str,
|
56
|
+
sample_input: Optional[model_types.SupportedDataType] = None,
|
57
|
+
is_sub_model: Optional[bool] = False,
|
58
|
+
**kwargs: Unpack[model_types.PyTorchSaveOptions],
|
59
|
+
) -> None:
|
60
|
+
import torch
|
61
|
+
|
62
|
+
assert isinstance(model, torch.nn.Module)
|
63
|
+
|
64
|
+
if not is_sub_model:
|
65
|
+
target_methods = model_meta_api._get_target_methods(
|
66
|
+
model=model,
|
67
|
+
target_methods=kwargs.pop("target_methods", None),
|
68
|
+
default_target_methods=_PyTorchHandler.DEFAULT_TARGET_METHODS,
|
69
|
+
)
|
70
|
+
|
71
|
+
def get_prediction(
|
72
|
+
target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
|
73
|
+
) -> model_types.SupportedLocalDataType:
|
74
|
+
if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
|
75
|
+
sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
|
76
|
+
model_signature._convert_local_data_to_df(sample_input)
|
77
|
+
)
|
78
|
+
|
79
|
+
model.eval()
|
80
|
+
target_method = getattr(model, target_method_name, None)
|
81
|
+
assert callable(target_method)
|
82
|
+
with torch.no_grad():
|
83
|
+
predictions_df = target_method(sample_input)
|
84
|
+
return predictions_df
|
85
|
+
|
86
|
+
model_meta = model_meta_api._validate_signature(
|
87
|
+
model=model,
|
88
|
+
model_meta=model_meta,
|
89
|
+
target_methods=target_methods,
|
90
|
+
sample_input=sample_input,
|
91
|
+
get_prediction_fn=get_prediction,
|
92
|
+
)
|
93
|
+
|
94
|
+
# Torch.save using pickle will not pickle the model definition if defined in the top level of a module.
|
95
|
+
# Make sure that the module where the model is defined gets pickled by value as well.
|
96
|
+
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
97
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
98
|
+
os.makedirs(model_blob_path, exist_ok=True)
|
99
|
+
with open(os.path.join(model_blob_path, _PyTorchHandler.MODEL_BLOB_FILE), "wb") as f:
|
100
|
+
torch.save(model, f, pickle_module=cloudpickle)
|
101
|
+
base_meta = model_meta_api._ModelBlobMetadata(
|
102
|
+
name=name, model_type=_PyTorchHandler.handler_type, path=_PyTorchHandler.MODEL_BLOB_FILE
|
103
|
+
)
|
104
|
+
model_meta.models[name] = base_meta
|
105
|
+
model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])
|
106
|
+
|
107
|
+
@staticmethod
|
108
|
+
def _load_model(
|
109
|
+
name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
|
110
|
+
) -> "torch.nn.Module":
|
111
|
+
import torch
|
112
|
+
|
113
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
114
|
+
if not hasattr(model_meta, "models"):
|
115
|
+
raise ValueError("Ill model metadata found.")
|
116
|
+
model_blobs_metadata = model_meta.models
|
117
|
+
if name not in model_blobs_metadata:
|
118
|
+
raise ValueError(f"Blob of model {name} does not exist.")
|
119
|
+
model_blob_metadata = model_blobs_metadata[name]
|
120
|
+
model_blob_filename = model_blob_metadata.path
|
121
|
+
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
122
|
+
m = torch.load(f)
|
123
|
+
assert isinstance(m, torch.nn.Module)
|
124
|
+
return m
|
125
|
+
|
126
|
+
@staticmethod
|
127
|
+
def _load_as_custom_model(
|
128
|
+
name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
|
129
|
+
) -> custom_model.CustomModel:
|
130
|
+
"""Create a custom model class wrap for unified interface when being deployed. The predict method will be
|
131
|
+
re-targeted based on target_method metadata.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
name: Name of the model.
|
135
|
+
model_meta: The model metadata.
|
136
|
+
model_blobs_dir_path: Directory path to the whole model.
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
The model object as a custom model.
|
140
|
+
"""
|
141
|
+
import torch
|
142
|
+
|
143
|
+
from snowflake.ml.model import custom_model
|
144
|
+
|
145
|
+
def _create_custom_model(
|
146
|
+
raw_model: "torch.nn.Module",
|
147
|
+
model_meta: model_meta_api.ModelMetadata,
|
148
|
+
) -> Type[custom_model.CustomModel]:
|
149
|
+
def fn_factory(
|
150
|
+
raw_model: "torch.nn.Module",
|
151
|
+
signature: model_signature.ModelSignature,
|
152
|
+
target_method: str,
|
153
|
+
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
154
|
+
@custom_model.inference_api
|
155
|
+
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
156
|
+
if X.isnull().any(axis=None):
|
157
|
+
raise ValueError("Tensor cannot handle null values.")
|
158
|
+
|
159
|
+
raw_model.eval()
|
160
|
+
t = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
161
|
+
|
162
|
+
with torch.no_grad():
|
163
|
+
res = getattr(raw_model, target_method)(t)
|
164
|
+
return model_signature._rename_pandas_df(
|
165
|
+
data=model_signature._SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
|
166
|
+
)
|
167
|
+
|
168
|
+
return fn
|
169
|
+
|
170
|
+
type_method_dict = {}
|
171
|
+
for target_method_name, sig in model_meta.signatures.items():
|
172
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
173
|
+
|
174
|
+
_PyTorchModel = type(
|
175
|
+
"_PyTorchModel",
|
176
|
+
(custom_model.CustomModel,),
|
177
|
+
type_method_dict,
|
178
|
+
)
|
179
|
+
|
180
|
+
return _PyTorchModel
|
181
|
+
|
182
|
+
raw_model = _PyTorchHandler._load_model(name, model_meta, model_blobs_dir_path)
|
183
|
+
_PyTorchModel = _create_custom_model(raw_model, model_meta)
|
184
|
+
pytorch_model = _PyTorchModel(custom_model.ModelContext())
|
185
|
+
|
186
|
+
return pytorch_model
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, Union, cast
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -80,6 +81,9 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
80
81
|
def get_prediction(
|
81
82
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
82
83
|
) -> model_types.SupportedLocalDataType:
|
84
|
+
if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
|
85
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
86
|
+
|
83
87
|
target_method = getattr(model, target_method_name, None)
|
84
88
|
assert callable(target_method)
|
85
89
|
predictions_df = target_method(sample_input)
|
@@ -101,6 +105,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
101
105
|
name=name, model_type=_SKLModelHandler.handler_type, path=_SKLModelHandler.MODEL_BLOB_FILE
|
102
106
|
)
|
103
107
|
model_meta.models[name] = base_meta
|
108
|
+
model_meta._include_if_absent([model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn")])
|
104
109
|
|
105
110
|
@staticmethod
|
106
111
|
def _load_model(
|
@@ -146,7 +151,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
146
151
|
) -> Type[custom_model.CustomModel]:
|
147
152
|
def fn_factory(
|
148
153
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
149
|
-
|
154
|
+
signature: model_signature.ModelSignature,
|
150
155
|
target_method: str,
|
151
156
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
152
157
|
@custom_model.inference_api
|
@@ -155,17 +160,18 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
155
160
|
|
156
161
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
157
162
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
158
|
-
# return a list of ndarrays. We need to
|
159
|
-
|
160
|
-
|
163
|
+
# return a list of ndarrays. We need to deal with them separately
|
164
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
165
|
+
else:
|
166
|
+
df = pd.DataFrame(res)
|
167
|
+
|
168
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
161
169
|
|
162
170
|
return fn
|
163
171
|
|
164
172
|
type_method_dict = {}
|
165
173
|
for target_method_name, sig in model_meta.signatures.items():
|
166
|
-
type_method_dict[target_method_name] = fn_factory(
|
167
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
168
|
-
)
|
174
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
169
175
|
|
170
176
|
_SKLModel = type(
|
171
177
|
"_SKLModel",
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -81,6 +82,9 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
81
82
|
def get_prediction(
|
82
83
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
83
84
|
) -> model_types.SupportedLocalDataType:
|
85
|
+
if not isinstance(sample_input, (pd.DataFrame,)):
|
86
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
87
|
+
|
84
88
|
target_method = getattr(model, target_method_name, None)
|
85
89
|
assert callable(target_method)
|
86
90
|
predictions_df = target_method(sample_input)
|
@@ -106,7 +110,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
106
110
|
model_dependencies = model._get_dependencies()
|
107
111
|
for dep in model_dependencies:
|
108
112
|
pkg_name = dep.split("==")[0]
|
109
|
-
_include_if_absent_pkgs.append((pkg_name, pkg_name))
|
113
|
+
_include_if_absent_pkgs.append(model_meta_api.Dependency(conda_name=pkg_name, pip_name=pkg_name))
|
110
114
|
model_meta._include_if_absent(_include_if_absent_pkgs)
|
111
115
|
|
112
116
|
@staticmethod
|
@@ -150,7 +154,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
150
154
|
) -> Type[custom_model.CustomModel]:
|
151
155
|
def fn_factory(
|
152
156
|
raw_model: "BaseEstimator",
|
153
|
-
|
157
|
+
signature: model_signature.ModelSignature,
|
154
158
|
target_method: str,
|
155
159
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
156
160
|
@custom_model.inference_api
|
@@ -159,17 +163,18 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
159
163
|
|
160
164
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
161
165
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
162
|
-
# return a list of ndarrays. We need to
|
163
|
-
|
164
|
-
|
166
|
+
# return a list of ndarrays. We need to deal with them separately
|
167
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
168
|
+
else:
|
169
|
+
df = pd.DataFrame(res)
|
170
|
+
|
171
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
165
172
|
|
166
173
|
return fn
|
167
174
|
|
168
175
|
type_method_dict = {}
|
169
176
|
for target_method_name, sig in model_meta.signatures.items():
|
170
|
-
type_method_dict[target_method_name] = fn_factory(
|
171
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
172
|
-
)
|
177
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
173
178
|
|
174
179
|
_SnowMLModel = type(
|
175
180
|
"_SnowMLModel",
|
@@ -0,0 +1,180 @@
|
|
1
|
+
import os
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
from typing_extensions import TypeGuard, Unpack
|
6
|
+
|
7
|
+
from snowflake.ml._internal import type_utils
|
8
|
+
from snowflake.ml.model import (
|
9
|
+
_model_meta as model_meta_api,
|
10
|
+
custom_model,
|
11
|
+
model_signature,
|
12
|
+
type_hints as model_types,
|
13
|
+
)
|
14
|
+
from snowflake.ml.model._handlers import _base
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
class _TorchScriptHandler(_base._ModelHandler["torch.jit.ScriptModule"]): # type:ignore[name-defined]
|
21
|
+
"""Handler for PyTorch JIT based model.
|
22
|
+
|
23
|
+
Currently torch.jit.ScriptModule based classes are supported.
|
24
|
+
"""
|
25
|
+
|
26
|
+
handler_type = "torchscript"
|
27
|
+
MODEL_BLOB_FILE = "model.pt"
|
28
|
+
DEFAULT_TARGET_METHODS = ["forward"]
|
29
|
+
|
30
|
+
@staticmethod
|
31
|
+
def can_handle(
|
32
|
+
model: model_types.SupportedModelType,
|
33
|
+
) -> TypeGuard["torch.jit.ScriptModule"]: # type:ignore[name-defined]
|
34
|
+
return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
|
35
|
+
|
36
|
+
@staticmethod
|
37
|
+
def cast_model(
|
38
|
+
model: model_types.SupportedModelType,
|
39
|
+
) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
|
40
|
+
import torch
|
41
|
+
|
42
|
+
assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
43
|
+
|
44
|
+
return cast(torch.jit.ScriptModule, model) # type:ignore[name-defined]
|
45
|
+
|
46
|
+
@staticmethod
|
47
|
+
def _save_model(
|
48
|
+
name: str,
|
49
|
+
model: "torch.jit.ScriptModule", # type:ignore[name-defined]
|
50
|
+
model_meta: model_meta_api.ModelMetadata,
|
51
|
+
model_blobs_dir_path: str,
|
52
|
+
sample_input: Optional[model_types.SupportedDataType] = None,
|
53
|
+
is_sub_model: Optional[bool] = False,
|
54
|
+
**kwargs: Unpack[model_types.TorchScriptSaveOptions],
|
55
|
+
) -> None:
|
56
|
+
import torch
|
57
|
+
|
58
|
+
assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
59
|
+
|
60
|
+
if not is_sub_model:
|
61
|
+
target_methods = model_meta_api._get_target_methods(
|
62
|
+
model=model,
|
63
|
+
target_methods=kwargs.pop("target_methods", None),
|
64
|
+
default_target_methods=_TorchScriptHandler.DEFAULT_TARGET_METHODS,
|
65
|
+
)
|
66
|
+
|
67
|
+
def get_prediction(
|
68
|
+
target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
|
69
|
+
) -> model_types.SupportedLocalDataType:
|
70
|
+
if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
|
71
|
+
sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
|
72
|
+
model_signature._convert_local_data_to_df(sample_input)
|
73
|
+
)
|
74
|
+
|
75
|
+
model.eval()
|
76
|
+
target_method = getattr(model, target_method_name, None)
|
77
|
+
assert callable(target_method)
|
78
|
+
with torch.no_grad():
|
79
|
+
predictions_df = target_method(sample_input)
|
80
|
+
return predictions_df
|
81
|
+
|
82
|
+
model_meta = model_meta_api._validate_signature(
|
83
|
+
model=model,
|
84
|
+
model_meta=model_meta,
|
85
|
+
target_methods=target_methods,
|
86
|
+
sample_input=sample_input,
|
87
|
+
get_prediction_fn=get_prediction,
|
88
|
+
)
|
89
|
+
|
90
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
91
|
+
os.makedirs(model_blob_path, exist_ok=True)
|
92
|
+
with open(os.path.join(model_blob_path, _TorchScriptHandler.MODEL_BLOB_FILE), "wb") as f:
|
93
|
+
torch.jit.save(model, f) # type:ignore[attr-defined]
|
94
|
+
base_meta = model_meta_api._ModelBlobMetadata(
|
95
|
+
name=name, model_type=_TorchScriptHandler.handler_type, path=_TorchScriptHandler.MODEL_BLOB_FILE
|
96
|
+
)
|
97
|
+
model_meta.models[name] = base_meta
|
98
|
+
model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])
|
99
|
+
|
100
|
+
@staticmethod
|
101
|
+
def _load_model(
|
102
|
+
name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
|
103
|
+
) -> "torch.jit.ScriptModule": # type:ignore[name-defined]
|
104
|
+
import torch
|
105
|
+
|
106
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
107
|
+
if not hasattr(model_meta, "models"):
|
108
|
+
raise ValueError("Ill model metadata found.")
|
109
|
+
model_blobs_metadata = model_meta.models
|
110
|
+
if name not in model_blobs_metadata:
|
111
|
+
raise ValueError(f"Blob of model {name} does not exist.")
|
112
|
+
model_blob_metadata = model_blobs_metadata[name]
|
113
|
+
model_blob_filename = model_blob_metadata.path
|
114
|
+
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
115
|
+
m = torch.jit.load(f) # type:ignore[attr-defined]
|
116
|
+
assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
117
|
+
return m
|
118
|
+
|
119
|
+
@staticmethod
def _load_as_custom_model(
    name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
) -> custom_model.CustomModel:
    """Load a TorchScript model and expose it through the CustomModel interface used for deployment.

    A wrapper class is generated on the fly with one inference method per entry in
    ``model_meta.signatures``; each method forwards to the attribute of the same name on the
    loaded module.

    Args:
        name: Name of the model.
        model_meta: The model metadata.
        model_blobs_dir_path: Directory path to the whole model.

    Returns:
        The model object as a custom model.
    """
    from snowflake.ml.model import custom_model

    def _make_inference_method(
        script_module: "torch.jit.ScriptModule",  # type:ignore[name-defined]
        method_signature: model_signature.ModelSignature,
        method_name: str,
    ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
        # A factory is used so that every generated method closes over its own
        # signature and target name instead of sharing loop variables.
        @custom_model.inference_api
        def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
            if X.isnull().any(axis=None):
                raise ValueError("Tensor cannot handle null values.")

            import torch

            # Put the module in evaluation mode before every call.
            script_module.eval()

            tensors = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, method_signature.inputs)

            # Gradient tracking is disabled while running the target method.
            with torch.no_grad():
                raw_output = getattr(script_module, method_name)(tensors)

            output_df = model_signature._SeqOfPyTorchTensorHandler.convert_to_df(raw_output)
            return model_signature._rename_pandas_df(data=output_df, features=method_signature.outputs)

        return fn

    script_module = _TorchScriptHandler._load_model(name, model_meta, model_blobs_dir_path)

    inference_methods = {
        method_name: _make_inference_method(script_module, method_sig, method_name)
        for method_name, method_sig in model_meta.signatures.items()
    }

    # Build the CustomModel subclass dynamically so its methods mirror the saved signatures.
    wrapper_cls = type("_TorchScriptModel", (custom_model.CustomModel,), inference_methods)
    return wrapper_cls(custom_model.ModelContext())
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, Union
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import pandas as pd
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -72,6 +73,9 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
72
73
|
def get_prediction(
|
73
74
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
74
75
|
) -> model_types.SupportedLocalDataType:
|
76
|
+
if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
|
77
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
78
|
+
|
75
79
|
target_method = getattr(model, target_method_name, None)
|
76
80
|
assert callable(target_method)
|
77
81
|
predictions_df = target_method(sample_input)
|
@@ -95,7 +99,12 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
95
99
|
options={"xgb_estimator_type": model.__class__.__name__},
|
96
100
|
)
|
97
101
|
model_meta.models[name] = base_meta
|
98
|
-
model_meta._include_if_absent(
|
102
|
+
model_meta._include_if_absent(
|
103
|
+
[
|
104
|
+
model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn"),
|
105
|
+
model_meta_api.Dependency(conda_name="xgboost", pip_name="xgboost"),
|
106
|
+
]
|
107
|
+
)
|
99
108
|
|
100
109
|
@staticmethod
|
101
110
|
def _load_model(
|
@@ -143,7 +152,7 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
143
152
|
) -> Type[custom_model.CustomModel]:
|
144
153
|
def fn_factory(
|
145
154
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
146
|
-
|
155
|
+
signature: model_signature.ModelSignature,
|
147
156
|
target_method: str,
|
148
157
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
149
158
|
@custom_model.inference_api
|
@@ -152,17 +161,18 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
152
161
|
|
153
162
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
154
163
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
155
|
-
# return a list of ndarrays. We need to
|
156
|
-
|
157
|
-
|
164
|
+
# return a list of ndarrays. We need to deal with them separately
|
165
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
166
|
+
else:
|
167
|
+
df = pd.DataFrame(res)
|
168
|
+
|
169
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
158
170
|
|
159
171
|
return fn
|
160
172
|
|
161
173
|
type_method_dict = {}
|
162
174
|
for target_method_name, sig in model_meta.signatures.items():
|
163
|
-
type_method_dict[target_method_name] = fn_factory(
|
164
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
165
|
-
)
|
175
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
166
176
|
|
167
177
|
_XGBModel = type(
|
168
178
|
"_XGBModel",
|
snowflake/ml/model/_model.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
import os
|
2
|
+
import posixpath
|
2
3
|
import tempfile
|
3
4
|
import warnings
|
4
5
|
from types import ModuleType
|
5
|
-
from typing import Dict, List, Literal, Optional, Tuple, Union, overload
|
6
|
+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, overload
|
6
7
|
|
7
|
-
from snowflake.ml._internal import file_utils
|
8
|
+
from snowflake.ml._internal import file_utils, type_utils
|
8
9
|
from snowflake.ml.model import (
|
9
10
|
_env,
|
10
11
|
_model_handler,
|
@@ -13,9 +14,11 @@ from snowflake.ml.model import (
|
|
13
14
|
model_signature,
|
14
15
|
type_hints as model_types,
|
15
16
|
)
|
16
|
-
from snowflake.ml.modeling.framework import base
|
17
17
|
from snowflake.snowpark import FileOperation, Session
|
18
18
|
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from snowflake.ml.modeling.framework import base
|
21
|
+
|
19
22
|
MODEL_BLOBS_DIR = "models"
|
20
23
|
|
21
24
|
|
@@ -23,7 +26,7 @@ MODEL_BLOBS_DIR = "models"
|
|
23
26
|
def save_model(
|
24
27
|
*,
|
25
28
|
name: str,
|
26
|
-
model: base.BaseEstimator,
|
29
|
+
model: "base.BaseEstimator",
|
27
30
|
model_dir_path: str,
|
28
31
|
metadata: Optional[Dict[str, str]] = None,
|
29
32
|
conda_dependencies: Optional[List[str]] = None,
|
@@ -135,7 +138,7 @@ def save_model(
|
|
135
138
|
def save_model(
|
136
139
|
*,
|
137
140
|
name: str,
|
138
|
-
model: base.BaseEstimator,
|
141
|
+
model: "base.BaseEstimator",
|
139
142
|
session: Session,
|
140
143
|
model_stage_file_path: str,
|
141
144
|
metadata: Optional[Dict[str, str]] = None,
|
@@ -322,9 +325,11 @@ def save_model(
|
|
322
325
|
+ f"{'None' if model_stage_file_path is None else 'specified'} at the same time."
|
323
326
|
)
|
324
327
|
|
325
|
-
if (
|
326
|
-
(signatures is
|
327
|
-
|
328
|
+
if (
|
329
|
+
(signatures is None)
|
330
|
+
and (sample_input is None)
|
331
|
+
and not type_utils.LazyType("snowflake.ml.modeling.framework.base.BaseEstimator").isinstance(model)
|
332
|
+
) or ((signatures is not None) and (sample_input is not None)):
|
328
333
|
raise ValueError(
|
329
334
|
"Signatures and sample_input both cannot be "
|
330
335
|
+ f"{'None for local model' if signatures is None else 'specified'} at the same time."
|
@@ -360,8 +365,8 @@ def save_model(
|
|
360
365
|
)
|
361
366
|
|
362
367
|
assert session and model_stage_file_path
|
363
|
-
if
|
364
|
-
raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
368
|
+
if posixpath.splitext(model_stage_file_path)[1] != ".zip":
|
369
|
+
raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
365
370
|
|
366
371
|
with tempfile.TemporaryDirectory() as temp_local_model_dir_path:
|
367
372
|
meta = _save(
|
@@ -397,15 +402,15 @@ def _save(
|
|
397
402
|
name: str,
|
398
403
|
model: model_types.SupportedModelType,
|
399
404
|
local_dir_path: str,
|
400
|
-
signatures: Optional[Dict[str, model_signature.ModelSignature]]
|
401
|
-
sample_input: Optional[model_types.SupportedDataType]
|
402
|
-
metadata: Optional[Dict[str, str]]
|
403
|
-
conda_dependencies: Optional[List[str]]
|
404
|
-
pip_requirements: Optional[List[str]]
|
405
|
-
python_version: Optional[str]
|
406
|
-
ext_modules: Optional[List[ModuleType]]
|
407
|
-
code_paths: Optional[List[str]]
|
408
|
-
options:
|
405
|
+
signatures: Optional[Dict[str, model_signature.ModelSignature]],
|
406
|
+
sample_input: Optional[model_types.SupportedDataType],
|
407
|
+
metadata: Optional[Dict[str, str]],
|
408
|
+
conda_dependencies: Optional[List[str]],
|
409
|
+
pip_requirements: Optional[List[str]],
|
410
|
+
python_version: Optional[str],
|
411
|
+
ext_modules: Optional[List[ModuleType]],
|
412
|
+
code_paths: Optional[List[str]],
|
413
|
+
options: model_types.ModelSaveOption,
|
409
414
|
) -> _model_meta.ModelMetadata:
|
410
415
|
local_dir_path = os.path.normpath(local_dir_path)
|
411
416
|
|
@@ -423,6 +428,7 @@ def _save(
|
|
423
428
|
conda_dependencies=conda_dependencies,
|
424
429
|
pip_requirements=pip_requirements,
|
425
430
|
python_version=python_version,
|
431
|
+
**options,
|
426
432
|
) as meta:
|
427
433
|
model_blobs_path = os.path.join(local_dir_path, MODEL_BLOBS_DIR)
|
428
434
|
os.makedirs(model_blobs_path, exist_ok=True)
|
@@ -538,8 +544,8 @@ def load_model(
|
|
538
544
|
return _load(local_dir_path=model_dir_path, meta_only=meta_only)
|
539
545
|
|
540
546
|
assert session and model_stage_file_path
|
541
|
-
if
|
542
|
-
raise ValueError("Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
547
|
+
if posixpath.splitext(model_stage_file_path)[1] != ".zip":
|
548
|
+
raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
543
549
|
|
544
550
|
fo = FileOperation(session=session)
|
545
551
|
zf = fo.get_stream(model_stage_file_path)
|