snowflake-ml-python 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +29 -7
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +24 -6
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +5 -2
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -9
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +3 -2
- snowflake/ml/model/_model_meta.py +12 -7
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +23 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -26
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -26
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -26
- snowflake/ml/modeling/cluster/birch.py +51 -26
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -26
- snowflake/ml/modeling/cluster/dbscan.py +51 -26
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -26
- snowflake/ml/modeling/cluster/k_means.py +51 -26
- snowflake/ml/modeling/cluster/mean_shift.py +51 -26
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -26
- snowflake/ml/modeling/cluster/optics.py +51 -26
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -26
- snowflake/ml/modeling/compose/column_transformer.py +51 -26
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -26
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -26
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -26
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -26
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -26
- snowflake/ml/modeling/covariance/oas.py +51 -26
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -26
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -26
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -26
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -26
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/pca.py +51 -26
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -26
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -26
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -26
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -26
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -26
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -26
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -26
- snowflake/ml/modeling/impute/knn_imputer.py +51 -26
- snowflake/ml/modeling/impute/missing_indicator.py +51 -26
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -26
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -26
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -26
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/lars.py +51 -26
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -26
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/perceptron.py +51 -26
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ridge.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -26
- snowflake/ml/modeling/manifold/isomap.py +51 -26
- snowflake/ml/modeling/manifold/mds.py +51 -26
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -26
- snowflake/ml/modeling/manifold/tsne.py +51 -26
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -26
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -26
- snowflake/ml/modeling/model_selection/grid_search_cv.py +51 -26
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -26
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -26
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -26
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -26
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -26
- snowflake/ml/modeling/svm/linear_svc.py +51 -26
- snowflake/ml/modeling/svm/linear_svr.py +51 -26
- snowflake/ml/modeling/svm/nu_svc.py +51 -26
- snowflake/ml/modeling/svm/nu_svr.py +51 -26
- snowflake/ml/modeling/svm/svc.py +51 -26
- snowflake/ml/modeling/svm/svr.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -26
- snowflake/ml/registry/model_registry.py +74 -56
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +27 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.2.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,180 @@
|
|
1
|
+
import os
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
from typing_extensions import TypeGuard, Unpack
|
6
|
+
|
7
|
+
from snowflake.ml._internal import type_utils
|
8
|
+
from snowflake.ml.model import (
|
9
|
+
_model_meta as model_meta_api,
|
10
|
+
custom_model,
|
11
|
+
model_signature,
|
12
|
+
type_hints as model_types,
|
13
|
+
)
|
14
|
+
from snowflake.ml.model._handlers import _base
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
class _TorchScriptHandler(_base._ModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
    """Model handler for TorchScript (PyTorch JIT) models.

    Objects that are instances of ``torch.jit.ScriptModule`` are accepted. The module is written to a single
    blob file with ``torch.jit.save`` and restored with ``torch.jit.load``.
    """

    handler_type = "torchscript"
    MODEL_BLOB_FILE = "model.pt"
    DEFAULT_TARGET_METHODS = ["forward"]

    @staticmethod
    def can_handle(
        model: model_types.SupportedModelType,
    ) -> TypeGuard["torch.jit.ScriptModule"]:  # type:ignore[name-defined]
        """Report whether ``model`` is a ``torch.jit.ScriptModule``, resolving the type lazily."""
        script_module_type = type_utils.LazyType("torch.jit.ScriptModule")
        return script_module_type.isinstance(model)

    @staticmethod
    def cast_model(
        model: model_types.SupportedModelType,
    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
        """Narrow ``model`` to ``torch.jit.ScriptModule`` (asserting that it is one) and return it."""
        import torch

        script_module_cls = torch.jit.ScriptModule  # type:ignore[attr-defined]
        assert isinstance(model, script_module_cls)
        return cast(torch.jit.ScriptModule, model)  # type:ignore[name-defined]

    @staticmethod
    def _save_model(
        name: str,
        model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.TorchScriptSaveOptions],
    ) -> None:
        """Write the TorchScript module to ``<model_blobs_dir_path>/<name>/model.pt`` and register it in metadata.

        Args:
            name: Name of the model; also used as the blob sub-directory name.
            model: The TorchScript module to persist.
            model_meta: Metadata object that receives the blob entry and the pytorch/torch dependency.
            model_blobs_dir_path: Directory under which the per-model blob directory is created.
            sample_input: Optional sample data forwarded to signature validation.
            is_sub_model: When truthy, target-method resolution and signature validation are skipped.
            **kwargs: Save options; ``target_methods`` is popped from here when resolving target methods.
        """
        import torch

        assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]

        if not is_sub_model:
            resolved_methods = model_meta_api._get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=_TorchScriptHandler.DEFAULT_TARGET_METHODS,
            )

            def run_sample(
                target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
            ) -> model_types.SupportedLocalDataType:
                # Invoked by signature validation to obtain sample outputs from the named method.
                tensor_handler = model_signature._SeqOfPyTorchTensorHandler
                model_input = sample_input
                if not tensor_handler.can_handle(model_input):
                    # Inputs that are not already a sequence of tensors are normalized through a DataFrame.
                    model_input = tensor_handler.convert_from_df(model_signature._convert_local_data_to_df(model_input))

                # NOTE: this switches the caller's module into eval mode as a side effect.
                model.eval()
                bound_method = getattr(model, target_method_name, None)
                assert callable(bound_method)
                with torch.no_grad():
                    output = bound_method(model_input)
                return output

            model_meta = model_meta_api._validate_signature(
                model=model,
                model_meta=model_meta,
                target_methods=resolved_methods,
                sample_input=sample_input,
                get_prediction_fn=run_sample,
            )

        blob_dir = os.path.join(model_blobs_dir_path, name)
        os.makedirs(blob_dir, exist_ok=True)
        blob_file = os.path.join(blob_dir, _TorchScriptHandler.MODEL_BLOB_FILE)
        with open(blob_file, "wb") as out_stream:
            torch.jit.save(model, out_stream)  # type:ignore[attr-defined]

        model_meta.models[name] = model_meta_api._ModelBlobMetadata(
            name=name,
            model_type=_TorchScriptHandler.handler_type,
            path=_TorchScriptHandler.MODEL_BLOB_FILE,
        )
        model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])

    @staticmethod
    def _load_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
        """Load the TorchScript module saved under ``name`` from ``model_blobs_dir_path``.

        Raises:
            ValueError: If the metadata has no ``models`` attribute or contains no blob entry for ``name``.
        """
        import torch

        if not hasattr(model_meta, "models"):
            raise ValueError("Ill model metadata found.")
        blobs = model_meta.models
        if name not in blobs:
            raise ValueError(f"Blob of model {name} does not exist.")

        blob_file = os.path.join(model_blobs_dir_path, name, blobs[name].path)
        with open(blob_file, "rb") as in_stream:
            loaded = torch.jit.load(in_stream)  # type:ignore[attr-defined]
        assert isinstance(loaded, torch.jit.ScriptModule)  # type:ignore[attr-defined]
        return loaded

    @staticmethod
    def _load_as_custom_model(
        name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
    ) -> custom_model.CustomModel:
        """Wrap the loaded TorchScript module in a CustomModel so it exposes the unified deployment interface.

        One inference method is generated per entry in ``model_meta.signatures``. Each method is named after
        its target method and forwards to the module attribute of the same name.

        Args:
            name: Name of the model.
            model_meta: The model metadata.
            model_blobs_dir_path: Directory path to the whole model.

        Returns:
            The model object as a custom model.
        """
        raw_model = _TorchScriptHandler._load_model(name, model_meta, model_blobs_dir_path)

        def make_inference_method(
            signature: model_signature.ModelSignature,
            target_method: str,
        ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
            # A factory is used so that each generated method binds its own signature and target method.
            @custom_model.inference_api
            def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                if X.isnull().any(axis=None):
                    raise ValueError("Tensor cannot handle null values.")

                import torch

                tensor_handler = model_signature._SeqOfPyTorchTensorHandler
                raw_model.eval()
                tensors = tensor_handler.convert_from_df(X, signature.inputs)
                with torch.no_grad():
                    res = getattr(raw_model, target_method)(tensors)
                return model_signature._rename_pandas_df(
                    data=tensor_handler.convert_to_df(res), features=signature.outputs
                )

            return fn

        methods = {
            method_name: make_inference_method(sig, method_name)
            for method_name, sig in model_meta.signatures.items()
        }
        wrapper_cls: Type[custom_model.CustomModel] = type(
            "_TorchScriptModel",
            (custom_model.CustomModel,),
            methods,
        )
        return wrapper_cls(custom_model.ModelContext())
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, Union
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import pandas as pd
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -72,6 +73,9 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
72
73
|
def get_prediction(
|
73
74
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
74
75
|
) -> model_types.SupportedLocalDataType:
|
76
|
+
if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
|
77
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
78
|
+
|
75
79
|
target_method = getattr(model, target_method_name, None)
|
76
80
|
assert callable(target_method)
|
77
81
|
predictions_df = target_method(sample_input)
|
@@ -95,7 +99,12 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
95
99
|
options={"xgb_estimator_type": model.__class__.__name__},
|
96
100
|
)
|
97
101
|
model_meta.models[name] = base_meta
|
98
|
-
model_meta._include_if_absent(
|
102
|
+
model_meta._include_if_absent(
|
103
|
+
[
|
104
|
+
model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn"),
|
105
|
+
model_meta_api.Dependency(conda_name="xgboost", pip_name="xgboost"),
|
106
|
+
]
|
107
|
+
)
|
99
108
|
|
100
109
|
@staticmethod
|
101
110
|
def _load_model(
|
@@ -143,7 +152,7 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
143
152
|
) -> Type[custom_model.CustomModel]:
|
144
153
|
def fn_factory(
|
145
154
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
146
|
-
|
155
|
+
signature: model_signature.ModelSignature,
|
147
156
|
target_method: str,
|
148
157
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
149
158
|
@custom_model.inference_api
|
@@ -152,17 +161,18 @@ class _XGBModelHandler(_base._ModelHandler[Union["xgboost.Booster", "xgboost.XGB
|
|
152
161
|
|
153
162
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
154
163
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
155
|
-
# return a list of ndarrays. We need to
|
156
|
-
|
157
|
-
|
164
|
+
# return a list of ndarrays. We need to deal them seperately
|
165
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
166
|
+
else:
|
167
|
+
df = pd.DataFrame(res)
|
168
|
+
|
169
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
158
170
|
|
159
171
|
return fn
|
160
172
|
|
161
173
|
type_method_dict = {}
|
162
174
|
for target_method_name, sig in model_meta.signatures.items():
|
163
|
-
type_method_dict[target_method_name] = fn_factory(
|
164
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
165
|
-
)
|
175
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
166
176
|
|
167
177
|
_XGBModel = type(
|
168
178
|
"_XGBModel",
|
snowflake/ml/model/_model.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import posixpath
|
2
3
|
import tempfile
|
3
4
|
import warnings
|
4
5
|
from types import ModuleType
|
@@ -364,7 +365,7 @@ def save_model(
|
|
364
365
|
)
|
365
366
|
|
366
367
|
assert session and model_stage_file_path
|
367
|
-
if
|
368
|
+
if posixpath.splitext(model_stage_file_path)[1] != ".zip":
|
368
369
|
raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
369
370
|
|
370
371
|
with tempfile.TemporaryDirectory() as temp_local_model_dir_path:
|
@@ -543,7 +544,7 @@ def load_model(
|
|
543
544
|
return _load(local_dir_path=model_dir_path, meta_only=meta_only)
|
544
545
|
|
545
546
|
assert session and model_stage_file_path
|
546
|
-
if
|
547
|
+
if posixpath.splitext(model_stage_file_path)[1] != ".zip":
|
547
548
|
raise ValueError(f"Provided model path in the stage {model_stage_file_path} must be a path to a zip file.")
|
548
549
|
|
549
550
|
fo = FileOperation(session=session)
|
@@ -3,10 +3,11 @@ import importlib
|
|
3
3
|
import os
|
4
4
|
import sys
|
5
5
|
import warnings
|
6
|
+
from collections import namedtuple
|
6
7
|
from contextlib import contextmanager
|
7
8
|
from datetime import datetime
|
8
9
|
from types import ModuleType
|
9
|
-
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence,
|
10
|
+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, cast
|
10
11
|
|
11
12
|
import cloudpickle
|
12
13
|
import yaml
|
@@ -24,6 +25,8 @@ from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
|
24
25
|
MODEL_METADATA_VERSION = 1
|
25
26
|
_BASIC_DEPENDENCIES = _core_requirements.REQUIREMENTS
|
26
27
|
|
28
|
+
Dependency = namedtuple("Dependency", ["conda_name", "pip_name"])
|
29
|
+
|
27
30
|
|
28
31
|
@dataclasses.dataclass
|
29
32
|
class _ModelBlobMetadata:
|
@@ -214,9 +217,11 @@ class ModelMetadata:
|
|
214
217
|
pip_requirements if pip_requirements else []
|
215
218
|
)
|
216
219
|
if "local_ml_library_version" in kwargs:
|
217
|
-
self._include_if_absent([(dep, dep) for dep in _BASIC_DEPENDENCIES])
|
220
|
+
self._include_if_absent([Dependency(conda_name=dep, pip_name=dep) for dep in _BASIC_DEPENDENCIES])
|
218
221
|
else:
|
219
|
-
self._include_if_absent(
|
222
|
+
self._include_if_absent(
|
223
|
+
[Dependency(conda_name=dep, pip_name=dep) for dep in _BASIC_DEPENDENCIES + [env_utils._SNOWML_PKG_NAME]]
|
224
|
+
)
|
220
225
|
|
221
226
|
self.__dict__.update(kwargs)
|
222
227
|
|
@@ -234,7 +239,7 @@ class ModelMetadata:
|
|
234
239
|
for req in reqs
|
235
240
|
)
|
236
241
|
|
237
|
-
def _include_if_absent(self, pkgs: List[
|
242
|
+
def _include_if_absent(self, pkgs: List[Dependency]) -> None:
|
238
243
|
conda_reqs_str, pip_reqs_str = tuple(zip(*pkgs))
|
239
244
|
pip_reqs = env_utils.validate_pip_requirement_string_list(list(pip_reqs_str))
|
240
245
|
conda_reqs = env_utils.validate_conda_dependency_string_list(list(conda_reqs_str))
|
@@ -327,7 +332,7 @@ class ModelMetadata:
|
|
327
332
|
path: The path of the directory to write a yaml file in it.
|
328
333
|
"""
|
329
334
|
model_yaml_path = os.path.join(path, ModelMetadata.MODEL_METADATA_FILE)
|
330
|
-
with open(model_yaml_path, "w") as out:
|
335
|
+
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
331
336
|
yaml.safe_dump({**self.to_dict(), "version": MODEL_METADATA_VERSION}, stream=out, default_flow_style=False)
|
332
337
|
|
333
338
|
env_dir_path = os.path.join(path, ModelMetadata.ENV_DIR)
|
@@ -350,7 +355,7 @@ class ModelMetadata:
|
|
350
355
|
Loaded model metadata object.
|
351
356
|
"""
|
352
357
|
model_yaml_path = os.path.join(path, ModelMetadata.MODEL_METADATA_FILE)
|
353
|
-
with open(model_yaml_path) as f:
|
358
|
+
with open(model_yaml_path, encoding="utf-8") as f:
|
354
359
|
loaded_mata = yaml.safe_load(f.read())
|
355
360
|
|
356
361
|
loaded_mata_version = loaded_mata.pop("version", None)
|
@@ -392,7 +397,7 @@ def _validate_signature(
|
|
392
397
|
if isinstance(sample_input, SnowparkDataFrame):
|
393
398
|
# Added because of Any from missing stubs.
|
394
399
|
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
395
|
-
local_sample_input =
|
400
|
+
local_sample_input = model_signature._SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
396
401
|
else:
|
397
402
|
local_sample_input = trunc_sample_input
|
398
403
|
for target_method in target_methods:
|