snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ def get_model_method_options_from_options(
|
|
27
27
|
options: type_hints.ModelSaveOption, target_method: str
|
28
28
|
) -> ModelMethodOptions:
|
29
29
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
30
|
-
if
|
30
|
+
if target_method == "explain":
|
31
31
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
32
32
|
method_option = options.get("method_options", {}).get(target_method, {})
|
33
33
|
global_function_type = options.get("function_type", default_function_type)
|
@@ -174,6 +174,18 @@ class ModelEnv:
|
|
174
174
|
except env_utils.DuplicateDependencyError:
|
175
175
|
pass
|
176
176
|
|
177
|
+
def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
|
178
|
+
"""Remove conda requirements from model env if present.
|
179
|
+
|
180
|
+
Args:
|
181
|
+
conda_pkgs: A list of package name to be removed from conda requirements.
|
182
|
+
"""
|
183
|
+
for pkg_name in conda_pkgs:
|
184
|
+
spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
|
185
|
+
if spec_conda:
|
186
|
+
channel, spec = spec_conda
|
187
|
+
self._conda_dependencies[channel].remove(spec)
|
188
|
+
|
177
189
|
def generate_env_for_cuda(self) -> None:
|
178
190
|
if self.cuda_version is None:
|
179
191
|
return
|
@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
|
|
179
179
|
return pd.DataFrame(explanations)
|
180
180
|
|
181
181
|
if hasattr(model, "classes_"):
|
182
|
-
classes_list = [str(cl) for cl in model.classes_]
|
182
|
+
classes_list = [str(cl) for cl in model.classes_]
|
183
183
|
len_classes = len(classes_list)
|
184
184
|
if explanations.shape[2] != len_classes:
|
185
185
|
raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
|
@@ -191,7 +191,11 @@ def convert_explanations_to_2D_df(
|
|
191
191
|
# convert to object or numpy creates strings of fixed length
|
192
192
|
return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
|
193
193
|
|
194
|
-
|
194
|
+
# convert to dict only for multiclass
|
195
|
+
if len(classes_list) > 2:
|
196
|
+
exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
|
197
|
+
else: # assumes index 1 is positive class always
|
198
|
+
exp_2d = np.apply_along_axis(lambda arr: arr[1], -1, explanations)
|
195
199
|
|
196
200
|
return pd.DataFrame(exp_2d)
|
197
201
|
|
@@ -9,17 +9,14 @@ from typing_extensions import TypeGuard, Unpack
|
|
9
9
|
from snowflake.ml._internal import type_utils
|
10
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
11
|
from snowflake.ml.model._packager.model_env import model_env
|
12
|
-
from snowflake.ml.model._packager.model_handlers import
|
13
|
-
_base,
|
14
|
-
_utils as handlers_utils,
|
15
|
-
model_objective_utils,
|
16
|
-
)
|
12
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
17
13
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
18
14
|
from snowflake.ml.model._packager.model_meta import (
|
19
15
|
model_blob_meta,
|
20
16
|
model_meta as model_meta_api,
|
21
17
|
model_meta_schema,
|
22
18
|
)
|
19
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
23
20
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
24
21
|
|
25
22
|
if TYPE_CHECKING:
|
@@ -97,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
97
94
|
sample_input_data=sample_input_data,
|
98
95
|
get_prediction_fn=get_prediction,
|
99
96
|
)
|
100
|
-
model_task_and_output =
|
101
|
-
model_meta.task = model_task_and_output.task
|
97
|
+
model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
|
98
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
102
99
|
if enable_explainability:
|
103
100
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
104
101
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import os
|
3
3
|
import pathlib
|
4
4
|
import sys
|
5
|
-
from typing import Dict, Optional, Type, final
|
5
|
+
from typing import Dict, Optional, Type, cast, final
|
6
6
|
|
7
7
|
import anyio
|
8
8
|
import cloudpickle
|
@@ -99,6 +99,8 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
101
|
assert handler is not None
|
102
|
+
if handler is None:
|
103
|
+
raise TypeError("Your input type to custom model is not currently supported")
|
102
104
|
sub_model = handler.cast_model(model_ref.model)
|
103
105
|
handler.save_model(
|
104
106
|
name=sub_name,
|
@@ -106,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
106
108
|
model_meta=model_meta,
|
107
109
|
model_blobs_dir_path=model_blobs_dir_path,
|
108
110
|
is_sub_model=True,
|
111
|
+
**cast(model_types.BaseModelSaveOption, kwargs),
|
109
112
|
)
|
110
113
|
|
111
114
|
# Make sure that the module where the model is defined get pickled by value as well.
|
@@ -173,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
173
176
|
name=sub_model_name,
|
174
177
|
model_meta=model_meta,
|
175
178
|
model_blobs_dir_path=model_blobs_dir_path,
|
179
|
+
**cast(model_types.BaseModelLoadOption, kwargs),
|
176
180
|
)
|
177
181
|
models[sub_model_name] = sub_model
|
178
182
|
reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
|
@@ -256,12 +256,20 @@ class HuggingFacePipelineHandler(
|
|
256
256
|
@staticmethod
|
257
257
|
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
|
258
258
|
device_config: Dict[str, Any] = {}
|
259
|
+
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
260
|
+
gpu_nums = 0
|
261
|
+
if cuda_visible_devices is not None:
|
262
|
+
gpu_nums = len(cuda_visible_devices.split(","))
|
259
263
|
if (
|
260
264
|
kwargs.get("use_gpu", False)
|
261
265
|
and kwargs.get("device_map", None) is None
|
262
266
|
and kwargs.get("device", None) is None
|
263
267
|
):
|
264
|
-
|
268
|
+
if gpu_nums == 0 or gpu_nums > 1:
|
269
|
+
# Use accelerator if there are multiple GPUs or no GPU
|
270
|
+
device_config["device_map"] = "auto"
|
271
|
+
else:
|
272
|
+
device_config["device"] = "cuda"
|
265
273
|
elif kwargs.get("device_map", None) is not None:
|
266
274
|
device_config["device_map"] = kwargs["device_map"]
|
267
275
|
elif kwargs.get("device", None) is not None:
|
@@ -310,6 +318,7 @@ class HuggingFacePipelineHandler(
|
|
310
318
|
m = transformers.pipeline(
|
311
319
|
model_blob_options["task"],
|
312
320
|
model=model_blob_file_or_dir_path,
|
321
|
+
trust_remote_code=True,
|
313
322
|
**device_config,
|
314
323
|
)
|
315
324
|
|
@@ -20,17 +20,14 @@ from typing_extensions import TypeGuard, Unpack
|
|
20
20
|
from snowflake.ml._internal import type_utils
|
21
21
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
22
22
|
from snowflake.ml.model._packager.model_env import model_env
|
23
|
-
from snowflake.ml.model._packager.model_handlers import
|
24
|
-
_base,
|
25
|
-
_utils as handlers_utils,
|
26
|
-
model_objective_utils,
|
27
|
-
)
|
23
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
28
24
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
29
25
|
from snowflake.ml.model._packager.model_meta import (
|
30
26
|
model_blob_meta,
|
31
27
|
model_meta as model_meta_api,
|
32
28
|
model_meta_schema,
|
33
29
|
)
|
30
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
34
31
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
35
32
|
|
36
33
|
if TYPE_CHECKING:
|
@@ -113,7 +110,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
113
110
|
sample_input_data=sample_input_data,
|
114
111
|
get_prediction_fn=get_prediction,
|
115
112
|
)
|
116
|
-
model_task_and_output =
|
113
|
+
model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
|
117
114
|
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
118
115
|
if enable_explainability:
|
119
116
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
@@ -199,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
199
196
|
with open(model_blob_file_path, "rb") as f:
|
200
197
|
model = cloudpickle.load(f)
|
201
198
|
assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
|
199
|
+
assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
|
202
200
|
|
203
201
|
return model
|
204
202
|
|
205
203
|
@classmethod
|
206
204
|
def convert_as_custom_model(
|
207
205
|
cls,
|
208
|
-
raw_model: Union["lightgbm.Booster", "lightgbm.
|
206
|
+
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
209
207
|
model_meta: model_meta_api.ModelMetadata,
|
210
208
|
background_data: Optional[pd.DataFrame] = None,
|
211
209
|
**kwargs: Unpack[model_types.LGBMModelLoadOptions],
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import inspect
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
@@ -155,8 +156,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
155
156
|
model_blob_filename = model_blob_metadata.path
|
156
157
|
model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
|
157
158
|
|
159
|
+
additional_kwargs = {}
|
160
|
+
if "trust_remote_code" in inspect.signature(sentence_transformers.SentenceTransformer).parameters:
|
161
|
+
additional_kwargs["trust_remote_code"] = True
|
162
|
+
|
158
163
|
model = sentence_transformers.SentenceTransformer(
|
159
|
-
model_blob_file_or_dir_path,
|
164
|
+
model_blob_file_or_dir_path,
|
165
|
+
device=cls._get_device_config(**kwargs),
|
166
|
+
**additional_kwargs,
|
160
167
|
)
|
161
168
|
return model
|
162
169
|
|
@@ -10,24 +10,35 @@ from typing_extensions import TypeGuard, Unpack
|
|
10
10
|
from snowflake.ml._internal import type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
14
|
-
_base,
|
15
|
-
_utils as handlers_utils,
|
16
|
-
model_objective_utils,
|
17
|
-
)
|
13
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
18
14
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
19
15
|
from snowflake.ml.model._packager.model_meta import (
|
20
16
|
model_blob_meta,
|
21
17
|
model_meta as model_meta_api,
|
22
18
|
model_meta_schema,
|
23
19
|
)
|
20
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
24
21
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
22
|
+
from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
|
25
23
|
|
26
24
|
if TYPE_CHECKING:
|
27
25
|
import sklearn.base
|
28
26
|
import sklearn.pipeline
|
29
27
|
|
30
28
|
|
29
|
+
def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
|
30
|
+
new_steps = []
|
31
|
+
for step_name, step in model.steps:
|
32
|
+
new_reg = step
|
33
|
+
if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None:
|
34
|
+
# Unpack estimator to open source.
|
35
|
+
new_reg = step._sklearn_estimator
|
36
|
+
new_steps.append((step_name, new_reg))
|
37
|
+
|
38
|
+
model.steps = new_steps
|
39
|
+
return model
|
40
|
+
|
41
|
+
|
31
42
|
@final
|
32
43
|
class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
|
33
44
|
"""Handler for scikit-learn based model.
|
@@ -104,6 +115,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
104
115
|
if sample_input_data is None:
|
105
116
|
raise ValueError("Sample input data is required to enable explainability.")
|
106
117
|
|
118
|
+
# If this is a pipeline and we are in the container runtime, check for distributed estimator.
|
119
|
+
if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
|
120
|
+
model = _unpack_container_runtime_pipeline(model)
|
121
|
+
|
107
122
|
if not is_sub_model:
|
108
123
|
target_methods = handlers_utils.get_target_methods(
|
109
124
|
model=model,
|
@@ -137,8 +152,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
137
152
|
sample_input_data, model_meta, explain_target_method
|
138
153
|
)
|
139
154
|
|
140
|
-
model_task_and_output_type =
|
141
|
-
model_meta.task = model_task_and_output_type.task
|
155
|
+
model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
|
156
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
|
142
157
|
|
143
158
|
# if users did not ask then we enable if we have background data
|
144
159
|
if enable_explainability is None:
|
@@ -180,6 +195,35 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
180
195
|
model_meta.models[name] = base_meta
|
181
196
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
182
197
|
|
198
|
+
# if model instance is a pipeline, check the pipeline steps
|
199
|
+
if isinstance(model, sklearn.pipeline.Pipeline):
|
200
|
+
for _, pipeline_step in model.steps:
|
201
|
+
if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType(
|
202
|
+
"lightgbm.Booster"
|
203
|
+
).isinstance(pipeline_step):
|
204
|
+
model_meta.env.include_if_absent(
|
205
|
+
[
|
206
|
+
model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
|
207
|
+
],
|
208
|
+
check_local_version=True,
|
209
|
+
)
|
210
|
+
elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType(
|
211
|
+
"xgboost.Booster"
|
212
|
+
).isinstance(pipeline_step):
|
213
|
+
model_meta.env.include_if_absent(
|
214
|
+
[
|
215
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
216
|
+
],
|
217
|
+
check_local_version=True,
|
218
|
+
)
|
219
|
+
elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step):
|
220
|
+
model_meta.env.include_if_absent(
|
221
|
+
[
|
222
|
+
model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
|
223
|
+
],
|
224
|
+
check_local_version=True,
|
225
|
+
)
|
226
|
+
|
183
227
|
if enable_explainability:
|
184
228
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
185
229
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
@@ -5,24 +5,20 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, fin
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
-
from packaging import version
|
9
8
|
from typing_extensions import TypeGuard, Unpack
|
10
9
|
|
11
10
|
from snowflake.ml._internal import type_utils
|
12
11
|
from snowflake.ml._internal.exceptions import exceptions
|
13
12
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
14
13
|
from snowflake.ml.model._packager.model_env import model_env
|
15
|
-
from snowflake.ml.model._packager.model_handlers import
|
16
|
-
_base,
|
17
|
-
_utils as handlers_utils,
|
18
|
-
model_objective_utils,
|
19
|
-
)
|
14
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
20
15
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
21
16
|
from snowflake.ml.model._packager.model_meta import (
|
22
17
|
model_blob_meta,
|
23
18
|
model_meta as model_meta_api,
|
24
19
|
model_meta_schema,
|
25
20
|
)
|
21
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
26
22
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
27
23
|
|
28
24
|
if TYPE_CHECKING:
|
@@ -72,41 +68,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
72
68
|
return cast("BaseEstimator", model)
|
73
69
|
|
74
70
|
@classmethod
|
75
|
-
def
|
76
|
-
from importlib import metadata as importlib_metadata
|
77
|
-
|
78
|
-
from packaging import version
|
79
|
-
|
80
|
-
local_version = None
|
81
|
-
|
82
|
-
try:
|
83
|
-
local_dist = importlib_metadata.distribution(pkg_name)
|
84
|
-
local_version = version.parse(local_dist.version)
|
85
|
-
except importlib_metadata.PackageNotFoundError:
|
86
|
-
pass
|
87
|
-
|
88
|
-
return local_version
|
89
|
-
|
90
|
-
@classmethod
|
91
|
-
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
|
92
|
-
|
93
|
-
local_xgb_version = cls._get_local_version_package("xgboost")
|
94
|
-
|
95
|
-
if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
|
96
|
-
if enable_explainability:
|
97
|
-
warnings.warn(
|
98
|
-
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
99
|
-
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
100
|
-
category=UserWarning,
|
101
|
-
stacklevel=1,
|
102
|
-
)
|
103
|
-
return False
|
104
|
-
return True
|
105
|
-
|
106
|
-
@classmethod
|
107
|
-
def _get_supported_object_for_explainability(
|
108
|
-
cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
|
109
|
-
) -> Any:
|
71
|
+
def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
|
110
72
|
from snowflake.ml.modeling import pipeline as snowml_pipeline
|
111
73
|
|
112
74
|
# handle pipeline objects separately
|
@@ -118,8 +80,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
118
80
|
if hasattr(estimator, method_name):
|
119
81
|
try:
|
120
82
|
result = getattr(estimator, method_name)()
|
121
|
-
if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
|
122
|
-
return None
|
123
83
|
return result
|
124
84
|
except exceptions.SnowflakeMLException:
|
125
85
|
pass # Do nothing and continue to the next method
|
@@ -168,7 +128,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
168
128
|
model_meta.signatures = temp_model_signature_dict
|
169
129
|
|
170
130
|
if enable_explainability or enable_explainability is None:
|
171
|
-
python_base_obj = cls._get_supported_object_for_explainability(model
|
131
|
+
python_base_obj = cls._get_supported_object_for_explainability(model)
|
172
132
|
if python_base_obj is None:
|
173
133
|
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
174
134
|
raise ValueError(
|
@@ -177,8 +137,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
177
137
|
# set None to False so we don't include shap in the environment
|
178
138
|
enable_explainability = False
|
179
139
|
else:
|
180
|
-
model_task_and_output_type =
|
181
|
-
model_meta.task = model_task_and_output_type.task
|
140
|
+
model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
|
141
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
|
182
142
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
183
143
|
model_meta = handlers_utils.add_explain_method_signature(
|
184
144
|
model_meta=model_meta,
|
@@ -213,28 +173,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
213
173
|
model_dependencies = model._get_dependencies()
|
214
174
|
for dep in model_dependencies:
|
215
175
|
pkg_name = dep.split("==")[0]
|
216
|
-
|
217
|
-
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
218
|
-
continue
|
219
|
-
|
220
|
-
local_xgb_version = cls._get_local_version_package("xgboost")
|
221
|
-
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
222
|
-
model_meta.env.include_if_absent(
|
223
|
-
[
|
224
|
-
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
225
|
-
],
|
226
|
-
check_local_version=False,
|
227
|
-
)
|
228
|
-
else:
|
229
|
-
model_meta.env.include_if_absent(
|
230
|
-
[
|
231
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
232
|
-
],
|
233
|
-
check_local_version=True,
|
234
|
-
)
|
176
|
+
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
235
177
|
|
236
178
|
if enable_explainability:
|
237
|
-
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
179
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
238
180
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
239
181
|
model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
|
240
182
|
|
@@ -13,6 +13,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
13
13
|
from snowflake.ml.model._packager.model_meta import (
|
14
14
|
model_blob_meta,
|
15
15
|
model_meta as model_meta_api,
|
16
|
+
model_meta_schema,
|
16
17
|
)
|
17
18
|
from snowflake.ml.model._signatures import (
|
18
19
|
numpy_handler,
|
@@ -76,7 +77,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
76
77
|
|
77
78
|
assert isinstance(model, tensorflow.Module)
|
78
79
|
|
79
|
-
|
80
|
+
is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
|
81
|
+
"tf_keras.Model"
|
82
|
+
).isinstance(model)
|
83
|
+
|
84
|
+
if is_keras_model:
|
80
85
|
default_target_methods = ["predict"]
|
81
86
|
else:
|
82
87
|
default_target_methods = cls.DEFAULT_TARGET_METHODS
|
@@ -117,8 +122,14 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
117
122
|
|
118
123
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
119
124
|
os.makedirs(model_blob_path, exist_ok=True)
|
120
|
-
if
|
125
|
+
if is_keras_model:
|
121
126
|
tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
127
|
+
model_meta.env.include_if_absent(
|
128
|
+
[
|
129
|
+
model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
|
130
|
+
],
|
131
|
+
check_local_version=False,
|
132
|
+
)
|
122
133
|
else:
|
123
134
|
tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
124
135
|
|
@@ -127,12 +138,16 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
127
138
|
model_type=cls.HANDLER_TYPE,
|
128
139
|
handler_version=cls.HANDLER_VERSION,
|
129
140
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
141
|
+
options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
|
130
142
|
)
|
131
143
|
model_meta.models[name] = base_meta
|
132
144
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
133
145
|
|
134
146
|
model_meta.env.include_if_absent(
|
135
|
-
[
|
147
|
+
[
|
148
|
+
model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
|
149
|
+
],
|
150
|
+
check_local_version=True,
|
136
151
|
)
|
137
152
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
138
153
|
|
@@ -150,9 +165,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
150
165
|
model_blobs_metadata = model_meta.models
|
151
166
|
model_blob_metadata = model_blobs_metadata[name]
|
152
167
|
model_blob_filename = model_blob_metadata.path
|
153
|
-
|
154
|
-
if
|
155
|
-
|
168
|
+
model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
|
169
|
+
if model_blob_options.get("is_keras_model", False):
|
170
|
+
m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
|
171
|
+
else:
|
172
|
+
m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
|
156
173
|
return cast(tensorflow.Module, m)
|
157
174
|
|
158
175
|
@classmethod
|
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|
23
23
|
|
24
24
|
|
25
25
|
@final
|
26
|
-
class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
26
|
+
class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
27
27
|
"""Handler for PyTorch JIT based model.
|
28
28
|
|
29
29
|
Currently torch.jit.ScriptModule based classes are supported.
|
@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
41
41
|
def can_handle(
|
42
42
|
cls,
|
43
43
|
model: model_types.SupportedModelType,
|
44
|
-
) -> TypeGuard["torch.jit.ScriptModule"]:
|
44
|
+
) -> TypeGuard["torch.jit.ScriptModule"]:
|
45
45
|
return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
|
46
46
|
|
47
47
|
@classmethod
|
48
48
|
def cast_model(
|
49
49
|
cls,
|
50
50
|
model: model_types.SupportedModelType,
|
51
|
-
) -> "torch.jit.ScriptModule":
|
51
|
+
) -> "torch.jit.ScriptModule":
|
52
52
|
import torch
|
53
53
|
|
54
|
-
assert isinstance(model, torch.jit.ScriptModule)
|
54
|
+
assert isinstance(model, torch.jit.ScriptModule)
|
55
55
|
|
56
|
-
return cast(torch.jit.ScriptModule, model)
|
56
|
+
return cast(torch.jit.ScriptModule, model)
|
57
57
|
|
58
58
|
@classmethod
|
59
59
|
def save_model(
|
60
60
|
cls,
|
61
61
|
name: str,
|
62
|
-
model: "torch.jit.ScriptModule",
|
62
|
+
model: "torch.jit.ScriptModule",
|
63
63
|
model_meta: model_meta_api.ModelMetadata,
|
64
64
|
model_blobs_dir_path: str,
|
65
65
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
@@ -72,7 +72,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
72
72
|
|
73
73
|
import torch
|
74
74
|
|
75
|
-
assert isinstance(model, torch.jit.ScriptModule)
|
75
|
+
assert isinstance(model, torch.jit.ScriptModule)
|
76
76
|
|
77
77
|
if not is_sub_model:
|
78
78
|
target_methods = handlers_utils.get_target_methods(
|
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
111
111
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
112
112
|
os.makedirs(model_blob_path, exist_ok=True)
|
113
113
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
114
|
-
torch.jit.save(model, f) # type:ignore[no-untyped-call
|
114
|
+
torch.jit.save(model, f) # type:ignore[no-untyped-call]
|
115
115
|
base_meta = model_blob_meta.ModelBlobMeta(
|
116
116
|
name=name,
|
117
117
|
model_type=cls.HANDLER_TYPE,
|
@@ -133,7 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
133
133
|
model_meta: model_meta_api.ModelMetadata,
|
134
134
|
model_blobs_dir_path: str,
|
135
135
|
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
136
|
-
) -> "torch.jit.ScriptModule":
|
136
|
+
) -> "torch.jit.ScriptModule":
|
137
137
|
import torch
|
138
138
|
|
139
139
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
@@ -141,10 +141,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
141
141
|
model_blob_metadata = model_blobs_metadata[name]
|
142
142
|
model_blob_filename = model_blob_metadata.path
|
143
143
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
144
|
-
m = torch.jit.load( # type:ignore[no-untyped-call
|
144
|
+
m = torch.jit.load( # type:ignore[no-untyped-call]
|
145
145
|
f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
|
146
146
|
)
|
147
|
-
assert isinstance(m, torch.jit.ScriptModule)
|
147
|
+
assert isinstance(m, torch.jit.ScriptModule)
|
148
148
|
|
149
149
|
if kwargs.get("use_gpu", False):
|
150
150
|
m = m.cuda()
|
@@ -154,7 +154,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
154
154
|
@classmethod
|
155
155
|
def convert_as_custom_model(
|
156
156
|
cls,
|
157
|
-
raw_model: "torch.jit.ScriptModule",
|
157
|
+
raw_model: "torch.jit.ScriptModule",
|
158
158
|
model_meta: model_meta_api.ModelMetadata,
|
159
159
|
background_data: Optional[pd.DataFrame] = None,
|
160
160
|
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
@@ -162,11 +162,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
162
162
|
from snowflake.ml.model import custom_model
|
163
163
|
|
164
164
|
def _create_custom_model(
|
165
|
-
raw_model: "torch.jit.ScriptModule",
|
165
|
+
raw_model: "torch.jit.ScriptModule",
|
166
166
|
model_meta: model_meta_api.ModelMetadata,
|
167
167
|
) -> Type[custom_model.CustomModel]:
|
168
168
|
def fn_factory(
|
169
|
-
raw_model: "torch.jit.ScriptModule",
|
169
|
+
raw_model: "torch.jit.ScriptModule",
|
170
170
|
signature: model_signature.ModelSignature,
|
171
171
|
target_method: str,
|
172
172
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|