snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +21 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +57 -0
- snowflake/ml/jobs/_utils/payload_utils.py +438 -0
- snowflake/ml/jobs/_utils/spec_utils.py +296 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +71 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_env/model_env.py +45 -28
- snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
- snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/core.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +11 -12
- snowflake/ml/model/_signatures/pandas_handler.py +11 -9
- snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +25 -4
- snowflake/ml/model/type_hints.py +15 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +28 -3
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/registry/registry.py +34 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/keras.py (new file)

@@ -0,0 +1,230 @@
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
+
+import cloudpickle
+import numpy as np
+import pandas as pd
+from packaging import version
+from typing_extensions import TypeGuard, Unpack
+
+from snowflake.ml._internal import type_utils
+from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model._packager.model_env import model_env
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_blob_meta,
+    model_meta as model_meta_api,
+)
+from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+
+if TYPE_CHECKING:
+    import keras
+
+
+@final
+class KerasHandler(_base.BaseModelHandler["keras.Model"]):
+    """Handler for Keras v3 model.
+
+    Currently keras.Model based classes are supported.
+    """
+
+    HANDLER_TYPE = "keras"
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.7.5"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+    MODEL_BLOB_FILE_OR_DIR = "model.keras"
+    CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
+    DEFAULT_TARGET_METHODS = ["predict"]
+
+    @classmethod
+    def can_handle(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> TypeGuard["keras.Model"]:
+        if not type_utils.LazyType("keras.Model").isinstance(model):
+            return False
+        import keras
+
+        return version.parse(keras.__version__) >= version.parse("3.0.0")
+
+    @classmethod
+    def cast_model(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> "keras.Model":
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        return cast(keras.Model, model)
+
+    @classmethod
+    def save_model(
+        cls,
+        name: str,
+        model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.TensorflowSaveOptions],
+    ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Tensorflow model.")
+
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        if not is_sub_model:
+            target_methods = handlers_utils.get_target_methods(
+                model=model,
+                target_methods=kwargs.pop("target_methods", None),
+                default_target_methods=cls.DEFAULT_TARGET_METHODS,
+            )
+
+            def get_prediction(
+                target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
+            ) -> model_types.SupportedLocalDataType:
+                target_method = getattr(model, target_method_name, None)
+                assert callable(target_method)
+                predictions_df = target_method(sample_input_data)
+
+                if (
+                    type_utils.LazyType("tensorflow.Tensor").isinstance(predictions_df)
+                    or type_utils.LazyType("tensorflow.Variable").isinstance(predictions_df)
+                    or type_utils.LazyType("torch.Tensor").isinstance(predictions_df)
+                ):
+                    predictions_df = [predictions_df]
+
+                return predictions_df
+
+            model_meta = handlers_utils.validate_signature(
+                model=model,
+                model_meta=model_meta,
+                target_methods=target_methods,
+                sample_input_data=sample_input_data,
+                get_prediction_fn=get_prediction,
+            )
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        os.makedirs(model_blob_path, exist_ok=True)
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        model.save(save_path)
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        custom_objects = keras.saving.get_custom_objects()
+        with open(custom_object_save_path, "wb") as f:
+            cloudpickle.dump(custom_objects, f)
+
+        base_meta = model_blob_meta.ModelBlobMeta(
+            name=name,
+            model_type=cls.HANDLER_TYPE,
+            handler_version=cls.HANDLER_VERSION,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
+        )
+        model_meta.models[name] = base_meta
+        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+        dependencies = [
+            model_env.ModelDependency(requirement="keras>=3", pip_name="keras"),
+        ]
+        keras_backend = keras.backend.backend()
+        if keras_backend == "tensorflow":
+            dependencies.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
+        elif keras_backend == "torch":
+            dependencies.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
+        elif keras_backend == "jax":
+            dependencies.append(model_env.ModelDependency(requirement="jax", pip_name="jax"))
+        else:
+            raise ValueError(f"Unsupported backend {keras_backend}")
+
+        model_meta.env.include_if_absent(
+            dependencies,
+            check_local_version=True,
+        )
+        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
+
+    @classmethod
+    def load_model(
+        cls,
+        name: str,
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> "keras.Model":
+        import keras
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        model_blobs_metadata = model_meta.models
+        model_blob_metadata = model_blobs_metadata[name]
+        model_blob_filename = model_blob_metadata.path
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        with open(custom_object_save_path, "rb") as f:
+            custom_objects = cloudpickle.load(f)
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        m = keras.models.load_model(load_path, custom_objects=custom_objects, safe_mode=False)
+
+        return cast(keras.Model, m)
+
+    @classmethod
+    def convert_as_custom_model(
+        cls,
+        raw_model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> custom_model.CustomModel:
+
+        from snowflake.ml.model import custom_model
+
+        def _create_custom_model(
+            raw_model: "keras.Model",
+            model_meta: model_meta_api.ModelMetadata,
+        ) -> Type[custom_model.CustomModel]:
+            def fn_factory(
+                raw_model: "keras.Model",
+                signature: model_signature.ModelSignature,
+                target_method: str,
+            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                dtype_map = {
+                    spec.name: spec.as_dtype(force_numpy_dtype=True)
+                    for spec in signature.inputs
+                    if isinstance(spec, model_signature.FeatureSpec)
+                }
+
+                @custom_model.inference_api
+                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                    res = getattr(raw_model, target_method)(X.astype(dtype_map), verbose=0)
+
+                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
+                        # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
+                        # return a list of ndarrays. We need to deal them separately
+                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
+                    else:
+                        df = pd.DataFrame(res)
+
+                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+                return fn
+
+            type_method_dict = {}
+            for target_method_name, sig in model_meta.signatures.items():
+                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+            _KerasModel = type(
+                "_KerasModel",
+                (custom_model.CustomModel,),
+                type_method_dict,
+            )
+
+            return _KerasModel
+
+        _KerasModel = _create_custom_model(raw_model, model_meta)
+        keras_model = _KerasModel(custom_model.ModelContext())
+
+        return keras_model
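The new handler saves the model in the native `.keras` format and pickles the Keras custom-object registry alongside it. Below is a minimal local sketch of that save/load round trip, assuming `keras>=3` and `cloudpickle` are installed; the toy model and temp directory are illustrative, while the file names mirror `MODEL_BLOB_FILE_OR_DIR` and `CUSTOM_OBJECT_SAVE_PATH` above:

```python
import os
import tempfile

import cloudpickle
import keras
import numpy as np

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

blob_dir = tempfile.mkdtemp()
save_path = os.path.join(blob_dir, "model.keras")  # MODEL_BLOB_FILE_OR_DIR
model.save(save_path)  # native Keras v3 archive

# The handler also pickles the custom-object registry next to the model blob.
with open(os.path.join(blob_dir, "custom_objects.pkl"), "wb") as f:
    cloudpickle.dump(keras.saving.get_custom_objects(), f)

# Loading mirrors KerasHandler.load_model.
with open(os.path.join(blob_dir, "custom_objects.pkl"), "rb") as f:
    custom_objects = cloudpickle.load(f)
m = keras.models.load_model(save_path, custom_objects=custom_objects, safe_mode=False)
print(m.predict(np.zeros((1, 4)), verbose=0).shape)  # (1, 1)
```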
snowflake/ml/model/_packager/model_handlers/pytorch.py

@@ -49,6 +49,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             type_utils.LazyType("torch.nn.Module").isinstance(model)
             and not type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
             and not type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model)
+            and not type_utils.LazyType("keras.Model").isinstance(model)
         )

     @classmethod
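The one-line exclusion matters because with `keras>=3` on the torch backend a Keras model is also a `torch.nn.Module`, so without the extra check the PyTorch handler could claim models the new Keras handler should own. A small sketch of the overlap, assuming keras>=3 and torch are installed and `KERAS_BACKEND=torch` is active:

```python
import keras
import torch

model = keras.Sequential([keras.Input(shape=(2,)), keras.layers.Dense(1)])
if keras.backend.backend() == "torch":
    # Both the torch.nn.Module and keras.Model type checks match this object.
    print(isinstance(model, torch.nn.Module))  # True
```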
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -292,12 +292,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
             import shap

-            # TODO: if not resolved by explainer, we need to pass the callable function
             try:
                 explainer = shap.Explainer(raw_model, background_data)
                 df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-            except TypeError as e:
-                raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+            except TypeError:
+                try:
+                    dtype_map = {
+                        spec.name: spec.as_dtype(force_numpy_dtype=True)  # type: ignore[attr-defined]
+                        for spec in signature.inputs
+                    }
+
+                    if isinstance(X, pd.DataFrame):
+                        X = X.astype(dtype_map, copy=False)
+                    if hasattr(raw_model, "predict_proba"):
+                        if isinstance(X, np.ndarray):
+                            explanations = shap.Explainer(
+                                raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
+                            )(X).values
+                        else:
+                            explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
+                    elif hasattr(raw_model, "predict"):
+                        if isinstance(X, np.ndarray):
+                            explanations = shap.Explainer(
+                                raw_model.predict, background_data.values  # type: ignore[union-attr]
+                            )(X).values
+                        else:
+                            explanations = shap.Explainer(raw_model.predict, background_data)(X).values
+                    else:
+                        raise ValueError("Missing any supported target method to explain.")
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
+                except TypeError as e:
+                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
             return model_signature_utils.rename_pandas_df(df, signature.outputs)

         if target_method == "explain":
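The hunk replaces a hard failure with a fallback: if `shap.Explainer` cannot be built from the estimator object, retry with a prediction callable (`predict_proba` first, then `predict`). A standalone sketch of the same pattern, assuming scikit-learn and shap are installed (the toy data is illustrative):

```python
import numpy as np
import shap
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.randn(100, 3)
y = (X[:, 0] > 0).astype(int)
model = LogisticRegression().fit(X, y)

try:
    # Direct construction, as the handler tries first.
    explainer = shap.Explainer(model, X)
except TypeError:
    # Callable fallback, as the new except-branch does.
    explainer = shap.Explainer(model.predict_proba, X)
print(explainer(X[:5]).values.shape)
```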
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

@@ -74,11 +74,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         background_data: Optional[model_types.SupportedDataType],
         enable_explainability: Optional[bool],
     ) -> Any:
-        from snowflake.ml.modeling import pipeline as snowml_pipeline
-
-        # handle pipeline objects separately
-        if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
-            return None

         tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
         non_tree_methods = ["to_sklearn"]
@@ -129,27 +124,54 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         # Pipeline is inherited from BaseEstimator, so no need to add one more check

         if not is_sub_model:
-            if model_meta.signatures:
+            if model_meta.signatures or sample_input_data is not None:
                 warnings.warn(
                     "Providing model signature for Snowpark ML "
                     + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
                     UserWarning,
                     stacklevel=2,
                 )
-            assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
-            model_signature_dict = getattr(model, "model_signatures", {})
-            target_methods = kwargs.pop("target_methods", None)
-            if not target_methods:
-                model_meta.signatures = model_signature_dict
+                target_methods = handlers_utils.get_target_methods(
+                    model=model,
+                    target_methods=kwargs.pop("target_methods", None),
+                    default_target_methods=cls.DEFAULT_TARGET_METHODS,
+                )
+
+                def get_prediction(
+                    target_method_name: str,
+                    sample_input_data: model_types.SupportedLocalDataType,
+                ) -> model_types.SupportedLocalDataType:
+                    if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
+                        sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
+
+                    target_method = getattr(model, target_method_name, None)
+                    assert callable(target_method)
+                    predictions_df = target_method(sample_input_data)
+                    return predictions_df
+
+                model_meta = handlers_utils.validate_signature(
+                    model=model,
+                    model_meta=model_meta,
+                    target_methods=target_methods,
+                    sample_input_data=sample_input_data,
+                    get_prediction_fn=get_prediction,
+                    is_for_modeling_model=True,
+                )
             else:
-                temp_model_signature_dict = {}
-                for method_name in target_methods:
-                    method_model_signature = model_signature_dict.get(method_name, None)
-                    if method_model_signature is not None:
-                        temp_model_signature_dict[method_name] = method_model_signature
-                    else:
-                        raise ValueError(f"Target method {method_name} does not exist in the model.")
-                model_meta.signatures = temp_model_signature_dict
+                assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
+                model_signature_dict = getattr(model, "model_signatures", {})
+                optional_target_methods = kwargs.pop("target_methods", None)
+                if not optional_target_methods:
+                    model_meta.signatures = model_signature_dict
+                else:
+                    temp_model_signature_dict = {}
+                    for method_name in optional_target_methods:
+                        method_model_signature = model_signature_dict.get(method_name, None)
+                        if method_model_signature is not None:
+                            temp_model_signature_dict[method_name] = method_model_signature
+                        else:
+                            raise ValueError(f"Target method {method_name} does not exist in the model.")
+                    model_meta.signatures = temp_model_signature_dict

         python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
         explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
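The else-branch added above reuses the signatures a fitted Snowpark ML estimator already carries, optionally filtered by the requested target methods. A pure-function restatement of that selection logic (the function and variable names here are illustrative, not part of the package):

```python
from typing import Dict, Optional, Sequence


def select_signatures(
    model_signature_dict: Dict[str, object],
    target_methods: Optional[Sequence[str]],
) -> Dict[str, object]:
    # No explicit target methods: keep every inferred signature.
    if not target_methods:
        return dict(model_signature_dict)
    selected = {}
    for method_name in target_methods:
        if method_name not in model_signature_dict:
            raise ValueError(f"Target method {method_name} does not exist in the model.")
        selected[method_name] = model_signature_dict[method_name]
    return selected


sigs = {"predict": "sig-a", "predict_proba": "sig-b"}  # placeholder signatures
print(select_signatures(sigs, None))         # keeps everything
print(select_signatures(sigs, ["predict"]))  # {'predict': 'sig-a'}
```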
@@ -279,9 +301,40 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
             for method_name in non_tree_methods:
                 try:
                     base_model = getattr(raw_model, method_name)()
-                    explainer = shap.Explainer(base_model, masker=background_data)
-                    df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                    except TypeError:
+                        try:
+                            dtype_map = {
+                                spec.name: spec.as_dtype(force_numpy_dtype=True)  # type: ignore[attr-defined]
+                                for spec in signature.inputs
+                            }
+
+                            if isinstance(X, pd.DataFrame):
+                                X = X.astype(dtype_map, copy=False)
+                            if hasattr(base_model, "predict_proba"):
+                                if isinstance(X, np.ndarray):
+                                    explainer = shap.Explainer(
+                                        base_model.predict_proba,
+                                        background_data.values,  # type: ignore[union-attr]
+                                    )
+                                else:
+                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
+                            elif hasattr(base_model, "predict"):
+                                if isinstance(X, np.ndarray):
+                                    explainer = shap.Explainer(
+                                        base_model.predict, background_data.values  # type: ignore[union-attr]
+                                    )
+                                else:
+                                    explainer = shap.Explainer(base_model.predict, background_data)
+                            else:
+                                raise ValueError("Missing any supported target method to explain.")
+                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                        except TypeError as e:
+                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
                     return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
                 except exceptions.SnowflakeMLException:
                     pass  # Do nothing and continue to the next method
             raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -10,7 +10,10 @@ from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
-from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_handlers_migrator import (
+    base_migrator,
+    tensorflow_migrator_2023_12_01,
+)
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
@@ -28,15 +31,17 @@ if TYPE_CHECKING:

 @final
 class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
-    """Handler for TensorFlow based model.
+    """Handler for TensorFlow based model or keras v2 model.

     Currently tensorflow.Module based classes are supported.
     """

     HANDLER_TYPE = "tensorflow"
-    HANDLER_VERSION = "2023-12-01"
-    _MIN_SNOWPARK_ML_VERSION = "1.0.12"
-    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.7.5"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
+        "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
+    }

     MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
@@ -46,7 +51,13 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         model: model_types.SupportedModelType,
     ) -> TypeGuard["tensorflow.nn.Module"]:
-        return type_utils.LazyType("tensorflow.Module").isinstance(model)
+        if not type_utils.LazyType("tensorflow.Module").isinstance(model):
+            return False
+        if type_utils.LazyType("keras.Model").isinstance(model):
+            import keras
+
+            return version.parse(keras.__version__) < version.parse("3.0.0")
+        return True

     @classmethod
     def cast_model(
@@ -74,44 +85,22 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")

-        # When tensorflow is installed, keras is also installed.
-        import keras
         import tensorflow

         assert isinstance(model, tensorflow.Module)

-        is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "keras.Model"
-        ).isinstance(model)
+        is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
-        is_keras_functional_or_sequential_model = (
-            getattr(model, "_is_graph_network", False)
-            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
-        )
-
-        assert isinstance(model, tensorflow.Module)
-
-        keras_version = version.parse(keras.__version__)
-
         # Tensorflow and keras model save format is different.
-        # Keras v2 functional or sequential models are saved using keras api
-        # Keras v3 models are saved using cloudpickle
-        # Keras v2 other models are saved using tensorflow saved model format
-        # Tensorflow models are saved using tensorflow saved model format
+        # Keras v2 models are saved using keras api
+        # Tensorflow models are saved using tensorflow api

         if is_keras_model or is_tf_keras_model:
-            if is_keras_functional_or_sequential_model:
-                save_format = "keras"
-            elif keras_version.major == 2 or is_tf_keras_model:
-                save_format = "keras_tf"
-            else:
-                save_format = "cloudpickle"
+            save_format = "keras_tf"
         else:
             save_format = "tf"

-        if is_keras_model:
+        if is_keras_model or is_tf_keras_model:
             default_target_methods = ["predict"]
         else:
             default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -156,15 +145,8 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
-        if save_format == "keras":
-            model.save(save_path, save_format="keras")
-        elif save_format == "keras_tf":
+        if save_format == "keras_tf":
             model.save(save_path, save_format="tf")
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(save_path, "wb") as f:
-                cloudpickle.dump(model, f)
         else:
             tensorflow.saved_model.save(
                 model,
@@ -186,7 +168,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
         ]
         if is_keras_model:
-            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+            dependencies.append(model_env.ModelDependency(requirement="keras<=3", pip_name="keras"))
         elif is_tf_keras_model:
             dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))

@@ -204,6 +186,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
+        os.environ["TF_USE_LEGACY_KERAS"] = "1"
         import tensorflow

         model_blob_path = os.path.join(model_blobs_dir_path, name)
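The added environment variable has to be set before `tensorflow` is imported: it makes `tf.keras` resolve to the legacy Keras 2 implementation (the `tf-keras` package) when that package is installed, which matches the `keras_tf` save format handled below. A sketch, assuming tensorflow and tf-keras are installed:

```python
import os

os.environ["TF_USE_LEGACY_KERAS"] = "1"  # must precede the tensorflow import
import tensorflow as tf

# When the flag is active and tf-keras is installed, tf.keras comes from the
# legacy tf_keras package rather than Keras 3.
print(tf.keras.Model.__module__)
```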
@@ -212,14 +195,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
         load_path = os.path.join(model_blob_path, model_blob_filename)
-        save_format = model_blob_options.get("save_format", "keras")
-        if save_format == "keras" or save_format == "keras_tf":
+        save_format = model_blob_options.get("save_format", "keras_tf")
+        if save_format == "keras_tf":
             m = tensorflow.keras.models.load_model(load_path)
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(load_path, "rb") as f:
-                m = cloudpickle.load(f)
         else:
             m = tensorflow.saved_model.load(load_path)

snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py (new file)

@@ -0,0 +1,48 @@
+from typing import cast
+
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_meta as model_meta_api,
+    model_meta_schema,
+)
+
+
+class TensorflowHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
+    source_version = "2023-12-01"
+    target_version = "2025-01-01"
+
+    @staticmethod
+    def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
+
+        model_blob_metadata = model_meta.models[name]
+        model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
+        # To migrate code <= 1.7.0, default to keras model
+        is_old_model = "save_format" not in model_blob_options and "is_keras_model" not in model_blob_options
+        # To migrate code form 1.7.1, default to False.
+        is_keras_model = model_blob_options.get("is_keras_model", False)
+        # To migrate code from 1.7.2, default to tf, has options keras, keras_tf, cloudpickle, tf
+        #
+        # if is_keras_model or is_tf_keras_model:
+        #     if is_keras_functional_or_sequential_model:
+        #         save_format = "keras"
+        #     elif keras_version.major == 2 or is_tf_keras_model:
+        #         save_format = "keras_tf"
+        #     else:
+        #         save_format = "cloudpickle"
+        # else:
+        #     save_format = "tf"
+        #
+        save_format = model_blob_options.get("save_format", "tf")
+
+        if save_format == "keras" or is_keras_model or is_old_model:
+            save_format = "keras_tf"
+        elif save_format == "cloudpickle":
+            # Given the old logic, this could only happen if the original model is a keras model, and keras is 3.x
+            # However, in this case, keras.Model does not extends from tensorflow.Module
+            # So actually TensorflowHandler will not be triggered, we could safely error this out.
+            raise NotImplementedError(
+                "Unable to upgrade keras 3.x model saved by old handler. This is not supposed to happen"
+            )
+
+        model_blob_options["save_format"] = save_format
+        model_meta.models[name].options = model_blob_options
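Since the migrator has to cover metadata written by 1.7.0, 1.7.1, and 1.7.2, its effective mapping is easier to see as a pure function. A restatement of the logic above (the function name is illustrative):

```python
def migrate_save_format(options: dict) -> str:
    # Pre-1.7.1 metadata has neither key and is treated as a keras model.
    is_old_model = "save_format" not in options and "is_keras_model" not in options
    is_keras_model = options.get("is_keras_model", False)  # 1.7.1 metadata
    save_format = options.get("save_format", "tf")         # 1.7.2 metadata

    if save_format == "keras" or is_keras_model or is_old_model:
        return "keras_tf"
    if save_format == "cloudpickle":
        raise NotImplementedError("Unable to upgrade keras 3.x model saved by old handler.")
    return save_format


assert migrate_save_format({}) == "keras_tf"                        # <= 1.7.0
assert migrate_save_format({"is_keras_model": True}) == "keras_tf"  # 1.7.1
assert migrate_save_format({"save_format": "keras"}) == "keras_tf"  # 1.7.2
assert migrate_save_format({"save_format": "tf"}) == "tf"
```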
snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -352,7 +352,7 @@ class ModelMetadata:
             version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_env.VERSION)
         ):
             raise RuntimeError(
-                f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version},"
+                f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
                 f"while current version of Snowpark ML library is {snowml_env.VERSION}."
             )
         return model_meta_schema.ModelMetadataDict(
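The one-character fix works because adjacent f-string literals concatenate with no separator, so the old message ran its two clauses together:

```python
v_req, v_cur = "1.7.5", "1.7.3"  # example versions
old = (
    f"The minimal version required to load the model is {v_req},"
    f"while current version of Snowpark ML library is {v_cur}."
)
new = (
    f"The minimal version required to load the model is {v_req}, "
    f"while current version of Snowpark ML library is {v_cur}."
)
print(old)  # ...is 1.7.5,while current version...
print(new)  # ...is 1.7.5, while current version...
```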
snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -44,6 +44,9 @@ class CatBoostModelBlobOptions(BaseModelBlobOptions):
 class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
     task: Required[str]
     batch_size: Required[int]
+    has_tokenizer: NotRequired[bool]
+    has_feature_extractor: NotRequired[bool]
+    has_image_preprocessor: NotRequired[bool]


 class LightGBMModelBlobOptions(BaseModelBlobOptions):
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'keras>=2.0.0,<4', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<3', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.7.0,<3', 'sentencepiece>=0.1.95,<0.2.0', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.17.0,<3', 'tokenizers>=0.15.1,<1', 'torchdata>=0.4,<1', 'transformers>=4.37.2,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_runtime/model_runtime.py

@@ -45,6 +45,7 @@ class ModelRuntime:
         self.name = name
         self.runtime_env = copy.deepcopy(env)
         self.imports = imports or []
+        self.is_gpu = is_gpu

         if loading_from_file:
             return
@@ -88,7 +89,9 @@ class ModelRuntime:
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path

-        env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
+        env_dict = self.runtime_env.save_as_dict(
+            packager_path, default_channel_override=default_channel_override, is_gpu=self.is_gpu
+        )

         return model_meta_schema.ModelRuntimeDict(
             imports=list(map(str, self.imports)),