snowflake-ml-python 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +29 -7
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +24 -6
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +5 -2
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -9
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +3 -2
- snowflake/ml/model/_model_meta.py +12 -7
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +23 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -26
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -26
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -26
- snowflake/ml/modeling/cluster/birch.py +51 -26
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -26
- snowflake/ml/modeling/cluster/dbscan.py +51 -26
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -26
- snowflake/ml/modeling/cluster/k_means.py +51 -26
- snowflake/ml/modeling/cluster/mean_shift.py +51 -26
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -26
- snowflake/ml/modeling/cluster/optics.py +51 -26
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -26
- snowflake/ml/modeling/compose/column_transformer.py +51 -26
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -26
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -26
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -26
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -26
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -26
- snowflake/ml/modeling/covariance/oas.py +51 -26
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -26
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -26
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -26
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -26
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/pca.py +51 -26
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -26
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -26
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -26
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -26
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -26
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -26
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -26
- snowflake/ml/modeling/impute/knn_imputer.py +51 -26
- snowflake/ml/modeling/impute/missing_indicator.py +51 -26
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -26
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -26
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -26
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/lars.py +51 -26
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -26
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/perceptron.py +51 -26
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ridge.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -26
- snowflake/ml/modeling/manifold/isomap.py +51 -26
- snowflake/ml/modeling/manifold/mds.py +51 -26
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -26
- snowflake/ml/modeling/manifold/tsne.py +51 -26
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -26
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -26
- snowflake/ml/modeling/model_selection/grid_search_cv.py +51 -26
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -26
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -26
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -26
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -26
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -26
- snowflake/ml/modeling/svm/linear_svc.py +51 -26
- snowflake/ml/modeling/svm/linear_svr.py +51 -26
- snowflake/ml/modeling/svm/nu_svc.py +51 -26
- snowflake/ml/modeling/svm/nu_svr.py +51 -26
- snowflake/ml/modeling/svm/svc.py +51 -26
- snowflake/ml/modeling/svm/svr.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -26
- snowflake/ml/registry/model_registry.py +74 -56
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +27 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.2.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -54,12 +54,15 @@ sys.path.insert(0, os.path.join(extracted_model_dir_path, "{code_dir_name}"))
|
|
54
54
|
from snowflake.ml.model import _model
|
55
55
|
model, meta = _model._load_model_for_deploy(extracted_model_dir_path)
|
56
56
|
|
57
|
+
features = meta.signatures["{target_method}"].inputs
|
58
|
+
input_cols = [feature.name for feature in features]
|
59
|
+
dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
60
|
+
|
57
61
|
# TODO(halu): Wire `max_batch_size`.
|
58
62
|
# TODO(halu): Avoid per batch async detection branching.
|
59
63
|
@vectorized(input=pd.DataFrame, max_batch_size=10)
|
60
64
|
def infer(df):
|
61
|
-
|
62
|
-
input_df = pd.io.json.json_normalize(df[0])
|
65
|
+
input_df = pd.io.json.json_normalize(df[0]).astype(dtype=dtype_map)
|
63
66
|
if inspect.iscoroutinefunction(model.{target_method}):
|
64
67
|
predictions_df = anyio.run(model.{target_method}, input_df[input_cols])
|
65
68
|
else:
|
snowflake/ml/model/_deployer.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
|
-
import json
|
2
1
|
import traceback
|
3
2
|
from enum import Enum
|
4
3
|
from typing import Optional, TypedDict, Union, overload
|
5
4
|
|
6
|
-
import numpy as np
|
7
5
|
import pandas as pd
|
8
6
|
from typing_extensions import Required
|
9
7
|
|
@@ -184,7 +182,6 @@ def predict(
|
|
184
182
|
|
185
183
|
Raises:
|
186
184
|
ValueError: Raised when the input is too large to use keep_order option.
|
187
|
-
NotImplementedError: FeatureGroupSpec is not supported.
|
188
185
|
|
189
186
|
Returns:
|
190
187
|
The output dataframe.
|
@@ -199,19 +196,19 @@ def predict(
|
|
199
196
|
# Validate and prepare input
|
200
197
|
if not isinstance(X, SnowparkDataFrame):
|
201
198
|
df = model_signature._convert_and_validate_local_data(X, sig.inputs)
|
202
|
-
s_df =
|
199
|
+
s_df = model_signature._SnowparkDataFrameHandler.convert_from_df(session, df, keep_order=keep_order)
|
203
200
|
else:
|
204
201
|
model_signature._validate_snowpark_data(X, sig.inputs)
|
205
202
|
s_df = X
|
206
203
|
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
204
|
+
if keep_order:
|
205
|
+
# ID is UINT64 type, this we should limit.
|
206
|
+
if s_df.count() > 2**64:
|
207
|
+
raise ValueError("Unable to keep order of a DataFrame with more than 2 ** 64 rows.")
|
208
|
+
s_df = s_df.with_column(
|
209
|
+
infer_template._KEEP_ORDER_COL_NAME,
|
210
|
+
F.monotonically_increasing_id(),
|
211
|
+
)
|
215
212
|
|
216
213
|
# Infer and get intermediate result
|
217
214
|
input_cols = []
|
@@ -223,7 +220,9 @@ def predict(
|
|
223
220
|
F.col(col_name),
|
224
221
|
]
|
225
222
|
)
|
226
|
-
output_obj = F.call_udf(
|
223
|
+
output_obj = F.call_udf(
|
224
|
+
identifier.get_inferred_name(deployment["name"]), F.object_construct(*input_cols) # type:ignore[arg-type]
|
225
|
+
)
|
227
226
|
if output_with_input_features:
|
228
227
|
df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
|
229
228
|
else:
|
@@ -243,24 +242,12 @@ def predict(
|
|
243
242
|
output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type()))
|
244
243
|
|
245
244
|
df_res = df_res.with_columns(
|
246
|
-
[identifier.
|
245
|
+
[identifier.get_inferred_name(output_feature.name) for output_feature in sig.outputs],
|
247
246
|
output_cols,
|
248
247
|
).drop(INTERMEDIATE_OBJ_NAME)
|
249
248
|
|
250
249
|
# Get final result
|
251
250
|
if not isinstance(X, SnowparkDataFrame):
|
252
|
-
|
253
|
-
for feature in sig.outputs:
|
254
|
-
if isinstance(feature, model_signature.FeatureGroupSpec):
|
255
|
-
raise NotImplementedError("FeatureGroupSpec is not supported.")
|
256
|
-
assert isinstance(feature, model_signature.FeatureSpec), "Invalid feature kind."
|
257
|
-
dtype_map[feature.name] = feature.as_dtype()
|
258
|
-
df_local = df_res.to_pandas()
|
259
|
-
# This is because Array and object will generate variant type and requires an additional loads to
|
260
|
-
# get correct data otherwise it would be string.
|
261
|
-
for col_name in [col_name for col_name, col_dtype in dtype_map.items() if col_dtype == np.object0]:
|
262
|
-
df_local[col_name] = df_local[col_name].map(json.loads)
|
263
|
-
df_local = df_local.astype(dtype=dtype_map)
|
264
|
-
return pd.DataFrame(df_local)
|
251
|
+
return model_signature._SnowparkDataFrameHandler.convert_to_df(df_res, features=sig.outputs)
|
265
252
|
else:
|
266
253
|
return df_res
|
snowflake/ml/model/_env.py
CHANGED
@@ -36,7 +36,7 @@ def save_conda_env_file(
|
|
36
36
|
for chan, reqs in deps.items():
|
37
37
|
env["dependencies"].extend([f"{chan}::{str(req)}" if chan else str(req) for req in reqs])
|
38
38
|
|
39
|
-
with open(path, "w") as f:
|
39
|
+
with open(path, "w", encoding="utf-8") as f:
|
40
40
|
yaml.safe_dump(env, stream=f, default_flow_style=False)
|
41
41
|
|
42
42
|
return path
|
@@ -54,7 +54,7 @@ def save_requirements_file(dir_path: str, pip_deps: List[requirements.Requiremen
|
|
54
54
|
"""
|
55
55
|
requirements = "\n".join(map(str, pip_deps))
|
56
56
|
path = os.path.join(dir_path, _REQUIREMENTS_FILE_NAME)
|
57
|
-
with open(path, "w") as out:
|
57
|
+
with open(path, "w", encoding="utf-8") as out:
|
58
58
|
out.write(requirements)
|
59
59
|
|
60
60
|
return path
|
@@ -69,7 +69,7 @@ def load_conda_env_file(path: str) -> Tuple[DefaultDict[str, List[requirements.R
|
|
69
69
|
Returns:
|
70
70
|
A tuple of Dict of conda dependencies after validated and a string 'major.minor.patchlevel' of python version.
|
71
71
|
"""
|
72
|
-
with open(path) as f:
|
72
|
+
with open(path, encoding="utf-8") as f:
|
73
73
|
env = yaml.safe_load(stream=f)
|
74
74
|
|
75
75
|
assert isinstance(env, dict)
|
@@ -99,7 +99,7 @@ def load_requirements_file(path: str) -> List[requirements.Requirement]:
|
|
99
99
|
Returns:
|
100
100
|
List of dependencies string after validated.
|
101
101
|
"""
|
102
|
-
with open(path) as f:
|
102
|
+
with open(path, encoding="utf-8") as f:
|
103
103
|
reqs = f.readlines()
|
104
104
|
|
105
105
|
return env_utils.validate_pip_requirement_string_list(reqs)
|
@@ -1,16 +1,19 @@
|
|
1
1
|
import inspect
|
2
2
|
import os
|
3
|
+
import pathlib
|
3
4
|
import sys
|
4
5
|
from typing import TYPE_CHECKING, Dict, Optional
|
5
6
|
|
6
7
|
import anyio
|
7
8
|
import cloudpickle
|
9
|
+
import pandas as pd
|
8
10
|
from typing_extensions import TypeGuard, Unpack
|
9
11
|
|
10
12
|
from snowflake.ml._internal import file_utils, type_utils
|
11
13
|
from snowflake.ml.model import (
|
12
14
|
_model_handler,
|
13
15
|
_model_meta as model_meta_api,
|
16
|
+
model_signature,
|
14
17
|
type_hints as model_types,
|
15
18
|
)
|
16
19
|
from snowflake.ml.model._handlers import _base
|
@@ -55,6 +58,10 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
|
|
55
58
|
target_method = getattr(model, target_method_name, None)
|
56
59
|
assert callable(target_method) and inspect.ismethod(target_method)
|
57
60
|
target_method = target_method.__func__
|
61
|
+
|
62
|
+
if not isinstance(sample_input, pd.DataFrame):
|
63
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
64
|
+
|
58
65
|
if inspect.iscoroutinefunction(target_method):
|
59
66
|
with anyio.start_blocking_portal() as portal:
|
60
67
|
predictions_df = portal.call(target_method, model, sample_input)
|
@@ -102,7 +109,9 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
|
|
102
109
|
model_type=_CustomModelHandler.handler_type,
|
103
110
|
path=_CustomModelHandler.MODEL_BLOB_FILE,
|
104
111
|
artifacts={
|
105
|
-
name:
|
112
|
+
name: pathlib.Path(
|
113
|
+
os.path.join(_CustomModelHandler.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
|
114
|
+
).as_posix()
|
106
115
|
for name, uri in model.context.artifacts.items()
|
107
116
|
},
|
108
117
|
)
|
@@ -129,7 +138,10 @@ class _CustomModelHandler(_base._ModelHandler["custom_model.CustomModel"]):
|
|
129
138
|
assert issubclass(ModelClass, custom_model.CustomModel)
|
130
139
|
|
131
140
|
artifacts_meta = model_blob_metadata.artifacts
|
132
|
-
artifacts = {
|
141
|
+
artifacts = {
|
142
|
+
name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
|
143
|
+
for name, rel_path in artifacts_meta.items()
|
144
|
+
}
|
133
145
|
models: Dict[str, model_types.SupportedModelType] = dict()
|
134
146
|
for sub_model_name, _ref in m.context.model_refs.items():
|
135
147
|
model_type = model_meta.models[sub_model_name].model_type
|
@@ -0,0 +1,186 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
4
|
+
|
5
|
+
import cloudpickle
|
6
|
+
import pandas as pd
|
7
|
+
from typing_extensions import TypeGuard, Unpack
|
8
|
+
|
9
|
+
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml.model import (
|
11
|
+
_model_meta as model_meta_api,
|
12
|
+
custom_model,
|
13
|
+
model_signature,
|
14
|
+
type_hints as model_types,
|
15
|
+
)
|
16
|
+
from snowflake.ml.model._handlers import _base
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
class _PyTorchHandler(_base._ModelHandler["torch.nn.Module"]):
|
23
|
+
"""Handler for PyTorch based model.
|
24
|
+
|
25
|
+
Currently torch.nn.Module based classes are supported.
|
26
|
+
"""
|
27
|
+
|
28
|
+
handler_type = "pytorch"
|
29
|
+
MODEL_BLOB_FILE = "model.pt"
|
30
|
+
DEFAULT_TARGET_METHODS = ["forward"]
|
31
|
+
|
32
|
+
@staticmethod
|
33
|
+
def can_handle(
|
34
|
+
model: model_types.SupportedModelType,
|
35
|
+
) -> TypeGuard["torch.nn.Module"]:
|
36
|
+
return type_utils.LazyType("torch.nn.Module").isinstance(model) and not type_utils.LazyType(
|
37
|
+
"torch.jit.ScriptModule"
|
38
|
+
).isinstance(model)
|
39
|
+
|
40
|
+
@staticmethod
|
41
|
+
def cast_model(
|
42
|
+
model: model_types.SupportedModelType,
|
43
|
+
) -> "torch.nn.Module":
|
44
|
+
import torch
|
45
|
+
|
46
|
+
assert isinstance(model, torch.nn.Module)
|
47
|
+
|
48
|
+
return cast(torch.nn.Module, model)
|
49
|
+
|
50
|
+
@staticmethod
|
51
|
+
def _save_model(
|
52
|
+
name: str,
|
53
|
+
model: "torch.nn.Module",
|
54
|
+
model_meta: model_meta_api.ModelMetadata,
|
55
|
+
model_blobs_dir_path: str,
|
56
|
+
sample_input: Optional[model_types.SupportedDataType] = None,
|
57
|
+
is_sub_model: Optional[bool] = False,
|
58
|
+
**kwargs: Unpack[model_types.PyTorchSaveOptions],
|
59
|
+
) -> None:
|
60
|
+
import torch
|
61
|
+
|
62
|
+
assert isinstance(model, torch.nn.Module)
|
63
|
+
|
64
|
+
if not is_sub_model:
|
65
|
+
target_methods = model_meta_api._get_target_methods(
|
66
|
+
model=model,
|
67
|
+
target_methods=kwargs.pop("target_methods", None),
|
68
|
+
default_target_methods=_PyTorchHandler.DEFAULT_TARGET_METHODS,
|
69
|
+
)
|
70
|
+
|
71
|
+
def get_prediction(
|
72
|
+
target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
|
73
|
+
) -> model_types.SupportedLocalDataType:
|
74
|
+
if not model_signature._SeqOfPyTorchTensorHandler.can_handle(sample_input):
|
75
|
+
sample_input = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(
|
76
|
+
model_signature._convert_local_data_to_df(sample_input)
|
77
|
+
)
|
78
|
+
|
79
|
+
model.eval()
|
80
|
+
target_method = getattr(model, target_method_name, None)
|
81
|
+
assert callable(target_method)
|
82
|
+
with torch.no_grad():
|
83
|
+
predictions_df = target_method(sample_input)
|
84
|
+
return predictions_df
|
85
|
+
|
86
|
+
model_meta = model_meta_api._validate_signature(
|
87
|
+
model=model,
|
88
|
+
model_meta=model_meta,
|
89
|
+
target_methods=target_methods,
|
90
|
+
sample_input=sample_input,
|
91
|
+
get_prediction_fn=get_prediction,
|
92
|
+
)
|
93
|
+
|
94
|
+
# Torch.save using pickle will not pickle the model definition if defined in the top level of a module.
|
95
|
+
# Make sure that the module where the model is defined get pickled by value as well.
|
96
|
+
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
97
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
98
|
+
os.makedirs(model_blob_path, exist_ok=True)
|
99
|
+
with open(os.path.join(model_blob_path, _PyTorchHandler.MODEL_BLOB_FILE), "wb") as f:
|
100
|
+
torch.save(model, f, pickle_module=cloudpickle)
|
101
|
+
base_meta = model_meta_api._ModelBlobMetadata(
|
102
|
+
name=name, model_type=_PyTorchHandler.handler_type, path=_PyTorchHandler.MODEL_BLOB_FILE
|
103
|
+
)
|
104
|
+
model_meta.models[name] = base_meta
|
105
|
+
model_meta._include_if_absent([model_meta_api.Dependency(conda_name="pytorch", pip_name="torch")])
|
106
|
+
|
107
|
+
@staticmethod
|
108
|
+
def _load_model(
|
109
|
+
name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
|
110
|
+
) -> "torch.nn.Module":
|
111
|
+
import torch
|
112
|
+
|
113
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
114
|
+
if not hasattr(model_meta, "models"):
|
115
|
+
raise ValueError("Ill model metadata found.")
|
116
|
+
model_blobs_metadata = model_meta.models
|
117
|
+
if name not in model_blobs_metadata:
|
118
|
+
raise ValueError(f"Blob of model {name} does not exist.")
|
119
|
+
model_blob_metadata = model_blobs_metadata[name]
|
120
|
+
model_blob_filename = model_blob_metadata.path
|
121
|
+
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
122
|
+
m = torch.load(f)
|
123
|
+
assert isinstance(m, torch.nn.Module)
|
124
|
+
return m
|
125
|
+
|
126
|
+
@staticmethod
|
127
|
+
def _load_as_custom_model(
|
128
|
+
name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str
|
129
|
+
) -> custom_model.CustomModel:
|
130
|
+
"""Create a custom model class wrap for unified interface when being deployed. The predict method will be
|
131
|
+
re-targeted based on target_method metadata.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
name: Name of the model.
|
135
|
+
model_meta: The model metadata.
|
136
|
+
model_blobs_dir_path: Directory path to the whole model.
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
The model object as a custom model.
|
140
|
+
"""
|
141
|
+
import torch
|
142
|
+
|
143
|
+
from snowflake.ml.model import custom_model
|
144
|
+
|
145
|
+
def _create_custom_model(
|
146
|
+
raw_model: "torch.nn.Module",
|
147
|
+
model_meta: model_meta_api.ModelMetadata,
|
148
|
+
) -> Type[custom_model.CustomModel]:
|
149
|
+
def fn_factory(
|
150
|
+
raw_model: "torch.nn.Module",
|
151
|
+
signature: model_signature.ModelSignature,
|
152
|
+
target_method: str,
|
153
|
+
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
154
|
+
@custom_model.inference_api
|
155
|
+
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
156
|
+
if X.isnull().any(axis=None):
|
157
|
+
raise ValueError("Tensor cannot handle null values.")
|
158
|
+
|
159
|
+
raw_model.eval()
|
160
|
+
t = model_signature._SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
161
|
+
|
162
|
+
with torch.no_grad():
|
163
|
+
res = getattr(raw_model, target_method)(t)
|
164
|
+
return model_signature._rename_pandas_df(
|
165
|
+
data=model_signature._SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
|
166
|
+
)
|
167
|
+
|
168
|
+
return fn
|
169
|
+
|
170
|
+
type_method_dict = {}
|
171
|
+
for target_method_name, sig in model_meta.signatures.items():
|
172
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
173
|
+
|
174
|
+
_PyTorchModel = type(
|
175
|
+
"_PyTorchModel",
|
176
|
+
(custom_model.CustomModel,),
|
177
|
+
type_method_dict,
|
178
|
+
)
|
179
|
+
|
180
|
+
return _PyTorchModel
|
181
|
+
|
182
|
+
raw_model = _PyTorchHandler._load_model(name, model_meta, model_blobs_dir_path)
|
183
|
+
_PyTorchModel = _create_custom_model(raw_model, model_meta)
|
184
|
+
pytorch_model = _PyTorchModel(custom_model.ModelContext())
|
185
|
+
|
186
|
+
return pytorch_model
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, Union, cast
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -80,6 +81,9 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
80
81
|
def get_prediction(
|
81
82
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
82
83
|
) -> model_types.SupportedLocalDataType:
|
84
|
+
if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
|
85
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
86
|
+
|
83
87
|
target_method = getattr(model, target_method_name, None)
|
84
88
|
assert callable(target_method)
|
85
89
|
predictions_df = target_method(sample_input)
|
@@ -101,7 +105,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
101
105
|
name=name, model_type=_SKLModelHandler.handler_type, path=_SKLModelHandler.MODEL_BLOB_FILE
|
102
106
|
)
|
103
107
|
model_meta.models[name] = base_meta
|
104
|
-
model_meta._include_if_absent([("scikit-learn", "scikit-learn")])
|
108
|
+
model_meta._include_if_absent([model_meta_api.Dependency(conda_name="scikit-learn", pip_name="scikit-learn")])
|
105
109
|
|
106
110
|
@staticmethod
|
107
111
|
def _load_model(
|
@@ -147,7 +151,7 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
147
151
|
) -> Type[custom_model.CustomModel]:
|
148
152
|
def fn_factory(
|
149
153
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
150
|
-
|
154
|
+
signature: model_signature.ModelSignature,
|
151
155
|
target_method: str,
|
152
156
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
153
157
|
@custom_model.inference_api
|
@@ -156,17 +160,18 @@ class _SKLModelHandler(_base._ModelHandler[Union["sklearn.base.BaseEstimator", "
|
|
156
160
|
|
157
161
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
158
162
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
159
|
-
# return a list of ndarrays. We need to
|
160
|
-
|
161
|
-
|
163
|
+
# return a list of ndarrays. We need to deal them seperately
|
164
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
165
|
+
else:
|
166
|
+
df = pd.DataFrame(res)
|
167
|
+
|
168
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
162
169
|
|
163
170
|
return fn
|
164
171
|
|
165
172
|
type_method_dict = {}
|
166
173
|
for target_method_name, sig in model_meta.signatures.items():
|
167
|
-
type_method_dict[target_method_name] = fn_factory(
|
168
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
169
|
-
)
|
174
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
170
175
|
|
171
176
|
_SKLModel = type(
|
172
177
|
"_SKLModel",
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable, Optional,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, Type, cast
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -10,6 +10,7 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import (
|
11
11
|
_model_meta as model_meta_api,
|
12
12
|
custom_model,
|
13
|
+
model_signature,
|
13
14
|
type_hints as model_types,
|
14
15
|
)
|
15
16
|
from snowflake.ml.model._handlers import _base
|
@@ -81,6 +82,9 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
81
82
|
def get_prediction(
|
82
83
|
target_method_name: str, sample_input: model_types.SupportedLocalDataType
|
83
84
|
) -> model_types.SupportedLocalDataType:
|
85
|
+
if not isinstance(sample_input, (pd.DataFrame,)):
|
86
|
+
sample_input = model_signature._convert_local_data_to_df(sample_input)
|
87
|
+
|
84
88
|
target_method = getattr(model, target_method_name, None)
|
85
89
|
assert callable(target_method)
|
86
90
|
predictions_df = target_method(sample_input)
|
@@ -106,7 +110,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
106
110
|
model_dependencies = model._get_dependencies()
|
107
111
|
for dep in model_dependencies:
|
108
112
|
pkg_name = dep.split("==")[0]
|
109
|
-
_include_if_absent_pkgs.append((pkg_name, pkg_name))
|
113
|
+
_include_if_absent_pkgs.append(model_meta_api.Dependency(conda_name=pkg_name, pip_name=pkg_name))
|
110
114
|
model_meta._include_if_absent(_include_if_absent_pkgs)
|
111
115
|
|
112
116
|
@staticmethod
|
@@ -150,7 +154,7 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
150
154
|
) -> Type[custom_model.CustomModel]:
|
151
155
|
def fn_factory(
|
152
156
|
raw_model: "BaseEstimator",
|
153
|
-
|
157
|
+
signature: model_signature.ModelSignature,
|
154
158
|
target_method: str,
|
155
159
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
156
160
|
@custom_model.inference_api
|
@@ -159,17 +163,18 @@ class _SnowMLModelHandler(_base._ModelHandler["BaseEstimator"]):
|
|
159
163
|
|
160
164
|
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
161
165
|
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
162
|
-
# return a list of ndarrays. We need to
|
163
|
-
|
164
|
-
|
166
|
+
# return a list of ndarrays. We need to deal them seperately
|
167
|
+
df = model_signature._SeqOfNumpyArrayHandler.convert_to_df(res)
|
168
|
+
else:
|
169
|
+
df = pd.DataFrame(res)
|
170
|
+
|
171
|
+
return model_signature._rename_pandas_df(df, signature.outputs)
|
165
172
|
|
166
173
|
return fn
|
167
174
|
|
168
175
|
type_method_dict = {}
|
169
176
|
for target_method_name, sig in model_meta.signatures.items():
|
170
|
-
type_method_dict[target_method_name] = fn_factory(
|
171
|
-
raw_model, [spec.name for spec in sig.outputs], target_method_name
|
172
|
-
)
|
177
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
173
178
|
|
174
179
|
_SnowMLModel = type(
|
175
180
|
"_SnowMLModel",
|