snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +26 -12
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +25 -9
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -1
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
- snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
- snowflake/ml/monitoring/model_monitor.py +12 -86
- snowflake/ml/registry/registry.py +28 -40
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py
@@ -6,7 +6,7 @@ import re
 import tempfile
 import threading
 import time
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 from packaging import version
 
@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import exceptions, row, session
+from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -107,8 +107,20 @@ class ServiceOperator:
         max_batch_rows: Optional[int],
         force_rebuild: bool,
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+        block: bool,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> str:
+    ) -> Union[str, async_job.AsyncJob]:
+
+        # Fall back to the registry's database and schema if not provided
+        database_name = database_name or self._database_name
+        schema_name = schema_name or self._schema_name
+
+        # Fall back to the model's database and schema if not provided then to the registry's database and schema
+        service_database_name = service_database_name or database_name or self._database_name
+        service_schema_name = service_schema_name or schema_name or self._schema_name
+
+        image_repo_database_name = image_repo_database_name or database_name or self._database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
         # create a temp stage
         stage_name = sql_identifier.SqlIdentifier(
             snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -130,8 +142,8 @@ class ServiceOperator:
             raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
 
         self._model_deployment_spec.save(
-            database_name=database_name or self._database_name,
-            schema_name=schema_name or self._schema_name,
+            database_name=database_name,
+            schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             service_database_name=service_database_name,
@@ -193,11 +205,15 @@ class ServiceOperator:
         log_thread = self._start_service_log_streaming(
             async_job, services, model_inference_service_exists, force_rebuild, statement_params
         )
-        log_thread.join()
-
-        res = cast(str, cast(List[row.Row], async_job.result())[0][0])
-        module_logger.info(f"Inference service {service_name} deployment complete: {res}")
-        return res
+        if block:
+            log_thread.join()
+
+            res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+            module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+            return res
+        else:
+            return async_job
 
     def _start_service_log_streaming(
         self,
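Taken together, these service_ops.py hunks let deployment return before the service finishes building. A minimal sketch of the non-blocking path, assuming the public ModelVersion.create_service API forwards the new `block` flag (the identifiers and the live Snowpark `session` are illustrative assumptions, not from this diff):

from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.get_model("MY_MODEL").version("V1")

job = mv.create_service(
    service_name="MY_SERVICE",
    service_compute_pool="MY_COMPUTE_POOL",
    image_repo="MY_IMAGE_REPO",
    block=False,  # returns a snowflake.snowpark.AsyncJob instead of a status string
)
if job.is_done():
    print(job.result())  # same status string the blocking path logs and returns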
snowflake/ml/model/_client/sql/model.py
@@ -17,8 +17,6 @@ class ModelSQLClient(_base._BaseSQLClient):
     MODEL_VERSION_ALIASES_COL_NAME = "aliases"
     MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
 
-    MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
-
     def show_models(
         self,
         *,
@@ -85,18 +83,6 @@ class ModelSQLClient(_base._BaseSQLClient):
 
         return res.validate()
 
-    def show_endpoints(
-        self,
-        *,
-        service_name: str,
-    ) -> List[row.Row]:
-        res = query_result_checker.SqlResultValidator(
-            self._session,
-            (f"SHOW ENDPOINTS IN SERVICE {service_name}"),
-        ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
-
-        return res.validate()
-
     def set_comment(
         self,
         *,
snowflake/ml/model/_client/sql/service.py
@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
-from snowflake.snowpark import dataframe, functions as F, types as spt
+from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
+    MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+
     def build_model_container(
         self,
         *,
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
             f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_endpoints(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[row.Row]:
+        fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        res = (
+            query_result_checker.SqlResultValidator(
+                self._session,
+                (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
+                statement_params=statement_params,
+            )
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
+        )
+
+        return res.validate()
snowflake/ml/model/_client/sql/stage.py
@@ -15,6 +15,6 @@ class StageSQLClient(_base._BaseSQLClient):
     ) -> None:
         query_result_checker.SqlResultValidator(
             self._session,
-            f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
+            f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
            statement_params=statement_params,
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
snowflake/ml/model/_model_composer/model_method/infer_function.py_template
@@ -5,6 +5,7 @@ import sys
 
 import anyio
 import pandas as pd
+import numpy as np
 from _snowflake import vectorized
 
 from snowflake.ml.model._packager import model_packager
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
     predictions_df = runner(input_df[input_cols])
-    return predictions_df.to_dict("records")
+    return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
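The doubled braces are template escaping; at render time the handler returns records with missing values scrubbed. A small demonstration of why, since NaN does not serialize cleanly while None does:

import numpy as np
import pandas as pd

predictions_df = pd.DataFrame({"score": [0.9, np.nan]})
print(predictions_df.to_dict("records"))
# [{'score': 0.9}, {'score': nan}]
print(predictions_df.replace({pd.NA: None, np.nan: None}).to_dict("records"))
# [{'score': 0.9}, {'score': None}]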
snowflake/ml/model/_packager/model_env/model_env.py
@@ -174,6 +174,18 @@ class ModelEnv:
         except env_utils.DuplicateDependencyError:
             pass
 
+    def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
+        """Remove conda requirements from model env if present.
+
+        Args:
+            conda_pkgs: A list of package name to be removed from conda requirements.
+        """
+        for pkg_name in conda_pkgs:
+            spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
+            if spec_conda:
+                channel, spec = spec_conda
+                self._conda_dependencies[channel].remove(spec)
+
     def generate_env_for_cuda(self) -> None:
         if self.cuda_version is None:
             return
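A sketch of the new helper in use; the conda_dependencies property setter and the internal channel-to-specs mapping it mutates are assumptions based on this diff, not a documented API:

from snowflake.ml.model._packager.model_env import model_env

env = model_env.ModelEnv()
env.conda_dependencies = ["pyarrow", "numpy>=1.23"]
env.remove_if_present_conda(["pyarrow"])
assert not any(dep.startswith("pyarrow") for dep in env.conda_dependencies)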
snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
         return pd.DataFrame(explanations)
 
     if hasattr(model, "classes_"):
-        classes_list = [str(cl) for cl in model.classes_]
+        classes_list = [str(cl) for cl in model.classes_]
        len_classes = len(classes_list)
        if explanations.shape[2] != len_classes:
            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
snowflake/ml/model/_packager/model_handlers/catboost.py
@@ -95,7 +95,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             get_prediction_fn=get_prediction,
         )
         model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
snowflake/ml/model/_packager/model_handlers/custom.py
@@ -2,7 +2,7 @@ import inspect
 import os
 import pathlib
 import sys
-from typing import Dict, Optional, Type, final
+from typing import Dict, Optional, Type, cast, final
 
 import anyio
 import cloudpickle
@@ -108,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
                 is_sub_model=True,
+                **cast(model_types.BaseModelSaveOption, kwargs),
             )
 
         # Make sure that the module where the model is defined get pickled by value as well.
@@ -175,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 name=sub_model_name,
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
+                **cast(model_types.BaseModelLoadOption, kwargs),
             )
             models[sub_model_name] = sub_model
         reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
snowflake/ml/model/_packager/model_handlers/lightgbm.py
@@ -196,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         with open(model_blob_file_path, "rb") as f:
             model = cloudpickle.load(f)
         assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
+        assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
 
         return model
 
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union["lightgbm.Booster", "lightgbm.
+        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],
snowflake/ml/model/_packager/model_handlers/sklearn.py
@@ -19,12 +19,26 @@ from snowflake.ml.model._packager.model_meta import (
 )
 from snowflake.ml.model._packager.model_task import model_task_utils
 from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
 
 if TYPE_CHECKING:
     import sklearn.base
     import sklearn.pipeline
 
 
+def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
+    new_steps = []
+    for step_name, step in model.steps:
+        new_reg = step
+        if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None:
+            # Unpack estimator to open source.
+            new_reg = step._sklearn_estimator
+        new_steps.append((step_name, new_reg))
+
+    model.steps = new_steps
+    return model
+
+
 @final
 class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
     """Handler for scikit-learn based model.
@@ -101,6 +115,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         if sample_input_data is None:
             raise ValueError("Sample input data is required to enable explainability.")
 
+        # If this is a pipeline and we are in the container runtime, check for distributed estimator.
+        if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
+            model = _unpack_container_runtime_pipeline(model)
+
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
@@ -135,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         )
 
         model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output_type.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
 
         # if users did not ask then we enable if we have background data
         if enable_explainability is None:
@@ -146,6 +164,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     stacklevel=1,
                 )
                 enable_explainability = False
+            elif model_meta.task == model_types.Task.UNKNOWN:
+                enable_explainability = False
            else:
                enable_explainability = True
        if enable_explainability:
@@ -177,6 +197,35 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        # if model instance is a pipeline, check the pipeline steps
+        if isinstance(model, sklearn.pipeline.Pipeline):
+            for _, pipeline_step in model.steps:
+                if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType(
+                    "lightgbm.Booster"
+                ).isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
+                        ],
+                        check_local_version=True,
+                    )
+                elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType(
+                    "xgboost.Booster"
+                ).isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
+                        ],
+                        check_local_version=True,
+                    )
+                elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
+                        ],
+                        check_local_version=True,
+                    )
+
         if enable_explainability:
             model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
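The net effect, sketched: logging a plain sklearn Pipeline whose step is a boosted model now records that library in the model environment. Assumes an existing Registry `reg` and training data `X`, `y` (both illustrative):

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier

pipe = Pipeline([("scale", StandardScaler()), ("clf", XGBClassifier())])
pipe.fit(X, y)

mv = reg.log_model(pipe, model_name="PIPE_MODEL", version_name="V1")
# the logged environment now pins xgboost, checked against the local version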
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
@@ -138,7 +138,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
             enable_explainability = False
         else:
             model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-            model_meta.task = model_task_and_output_type.task
+            model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,
snowflake/ml/model/_packager/model_handlers/tensorflow.py
@@ -13,6 +13,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
 )
 from snowflake.ml.model._signatures import (
     numpy_handler,
@@ -76,7 +77,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
 
         assert isinstance(model, tensorflow.Module)
 
-        if type_utils.LazyType("tensorflow.keras.Model").isinstance(model):
+        is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
+            "tf_keras.Model"
+        ).isinstance(model)
+
+        if is_keras_model:
             default_target_methods = ["predict"]
         else:
             default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -117,8 +122,14 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        if type_utils.LazyType("tensorflow.keras.Model").isinstance(model):
+        if is_keras_model:
             tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+            model_meta.env.include_if_absent(
+                [
+                    model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
+                ],
+                check_local_version=False,
+            )
         else:
             tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
@@ -127,12 +138,16 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
         model_meta.env.include_if_absent(
-            [model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")],
+            [
+                model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+            ],
+            check_local_version=True,
         )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
@@ -150,9 +165,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_metadata = model_meta.models
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
-
-        if
-
+        model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
+        if model_blob_options.get("is_keras_model", False):
+            m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
+        else:
+            m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
         return cast(tensorflow.Module, m)
 
     @classmethod
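The flag round-trips through the new TypedDict added to model_meta_schema.py (see the hunk further below); TypedDicts are plain dicts at runtime, which is why load_model can read the option back with .get():

from snowflake.ml.model._packager.model_meta import model_meta_schema

opts = model_meta_schema.TensorflowModelBlobOptions(is_keras_model=True)
assert opts.get("is_keras_model", False) is True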
snowflake/ml/model/_packager/model_handlers/torchscript.py
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 
 
 @final
-class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
+class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
     """Handler for PyTorch JIT based model.
 
     Currently torch.jit.ScriptModule based classes are supported.
@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
     def can_handle(
         cls,
         model: model_types.SupportedModelType,
-    ) -> TypeGuard["torch.jit.ScriptModule"]:
+    ) -> TypeGuard["torch.jit.ScriptModule"]:
         return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
 
     @classmethod
     def cast_model(
         cls,
         model: model_types.SupportedModelType,
-    ) -> "torch.jit.ScriptModule":
+    ) -> "torch.jit.ScriptModule":
         import torch
 
-        assert isinstance(model, torch.jit.ScriptModule)
+        assert isinstance(model, torch.jit.ScriptModule)
 
-        return cast(torch.jit.ScriptModule, model)
+        return cast(torch.jit.ScriptModule, model)
 
     @classmethod
     def save_model(
         cls,
         name: str,
-        model: "torch.jit.ScriptModule",
+        model: "torch.jit.ScriptModule",
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -72,7 +72,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
 
         import torch
 
-        assert isinstance(model, torch.jit.ScriptModule)
+        assert isinstance(model, torch.jit.ScriptModule)
 
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
-            torch.jit.save(model, f)  # type:ignore[no-untyped-call
+            torch.jit.save(model, f)  # type:ignore[no-untyped-call]
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
@@ -133,7 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
-    ) -> "torch.jit.ScriptModule":
+    ) -> "torch.jit.ScriptModule":
         import torch
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -141,10 +141,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
         with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
-            m = torch.jit.load(  # type:ignore[no-untyped-call
+            m = torch.jit.load(  # type:ignore[no-untyped-call]
                 f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
             )
-        assert isinstance(m, torch.jit.ScriptModule)
+        assert isinstance(m, torch.jit.ScriptModule)
 
         if kwargs.get("use_gpu", False):
             m = m.cuda()
@@ -154,7 +154,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: "torch.jit.ScriptModule",
+        raw_model: "torch.jit.ScriptModule",
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
@@ -162,11 +162,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         from snowflake.ml.model import custom_model
 
         def _create_custom_model(
-            raw_model: "torch.jit.ScriptModule",
+            raw_model: "torch.jit.ScriptModule",
             model_meta: model_meta_api.ModelMetadata,
         ) -> Type[custom_model.CustomModel]:
             def fn_factory(
-                raw_model: "torch.jit.ScriptModule",
+                raw_model: "torch.jit.ScriptModule",
                 signature: model_signature.ModelSignature,
                 target_method: str,
             ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
snowflake/ml/model/_packager/model_meta/_packaging_requirements.py
@@ -1,3 +1,2 @@
-REQUIREMENTS = [
-    "cloudpickle>=2.0.0",
-]
+REQUIREMENTS = ['cloudpickle>=2.0.0']
+ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
@@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
     xgb_estimator_type: Required[str]
 
 
+class TensorflowModelBlobOptions(BaseModelBlobOptions):
+    is_keras_model: Required[bool]
+
+
 ModelBlobOptions = Union[
     BaseModelBlobOptions,
     HuggingFacePipelineModelBlobOptions,
     MLFlowModelBlobOptions,
     XgboostModelBlobOptions,
+    TensorflowModelBlobOptions,
 ]
 
 
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -1,10 +1,2 @@
-REQUIREMENTS = [
-
-    "anyio>=3.5.0,<4",
-    "numpy>=1.23,<2",
-    "packaging>=20.9,<24",
-    "pandas>=1.0.0,<3",
-    "pyyaml>=6.0,<7",
-    "snowflake-snowpark-python>=1.17.0,<2",
-    "typing-extensions>=4.1.0,<5"
-]
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_runtime/model_runtime.py
@@ -17,6 +17,8 @@ _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES = [
     for r in _snowml_inference_alternative_requirements.REQUIREMENTS
 ]
 
+PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]
+
 
 class ModelRuntime:
     """Class to represent runtime in a model, which controls the runtime and version, imports and dependencies.
@@ -61,15 +63,8 @@ class ModelRuntime:
             ],
         )
 
-        if
-            self.runtime_env.include_if_absent(
-                [
-                    model_env.ModelDependency(
-                        requirement="pyarrow",
-                        pip_name="pyarrow",
-                    )
-                ],
-            )
+        if is_warehouse and self.embed_local_ml_library:
+            self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)
 
         if is_gpu:
             self.runtime_env.generate_env_for_cuda()
snowflake/ml/model/_packager/model_task/model_task_utils.py
@@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel
     if type_utils.LazyType("lightgbm.Booster").isinstance(model):
         model_task = model.params["objective"]  # type: ignore[attr-defined]
     elif hasattr(model, "objective_"):
-        model_task = model.objective_
+        model_task = model.objective_  # type: ignore[assignment]
     if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
         return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
     if model_task in _MULTI_CLASSIFICATION_OBJECTIVES: