snowflake-ml-python 1.8.5__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/payload_utils.py +83 -35
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +23 -1
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +6 -7
- snowflake/ml/jobs/job.py +24 -9
- snowflake/ml/jobs/manager.py +102 -19
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +19 -4
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +14 -5
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +33 -30
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py

```diff
@@ -325,13 +325,14 @@ class ServiceOperator:
                 )
                 continue
 
-            …
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=service_log_meta.service.database_name,
                 schema_name=service_log_meta.service.schema_name,
                 service_name=service_log_meta.service.service_name,
                 include_message=True,
                 statement_params=statement_params,
             )
+            service_status = statuses[0].service_status
             if (service_status != service_sql.ServiceStatus.RUNNING) or (
                 service_status != service_log_meta.service_status
             ):
@@ -341,7 +342,19 @@
                     f"{service_log_meta.service.display_service_name} is "
                     f"{service_log_meta.service_status.value}."
                 )
-                …
+                for status in statuses:
+                    if status.instance_id is not None:
+                        instance_status, container_status = None, None
+                        if status.instance_status is not None:
+                            instance_status = status.instance_status.value
+                        if status.container_status is not None:
+                            container_status = status.container_status.value
+                        module_logger.info(
+                            f"Instance[{status.instance_id}]: "
+                            f"instance status: {instance_status}, "
+                            f"container status: {container_status}, "
+                            f"message: {status.message}"
+                        )
 
             new_logs, new_offset = fetch_logs(
                 service_log_meta.service,
@@ -353,13 +366,14 @@
 
             # check if model build service is done
             if not service_log_meta.is_model_build_service_done:
-                …
+                statuses = self._service_client.get_service_container_statuses(
                     database_name=model_build_service.database_name,
                     schema_name=model_build_service.schema_name,
                     service_name=model_build_service.service_name,
                     include_message=False,
                     statement_params=statement_params,
                 )
+                service_status = statuses[0].service_status
 
                 if service_status == service_sql.ServiceStatus.DONE:
                     set_service_log_metadata_to_model_inference(
@@ -436,13 +450,14 @@
             service_sql.ServiceStatus.FAILED,
         ]
         try:
-            …
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=database_name,
                 schema_name=schema_name,
                 service_name=service_name,
                 include_message=False,
                 statement_params=statement_params,
             )
+            service_status = statuses[0].service_status
             return any(service_status == status for status in service_status_list_if_exists)
         except exceptions.SnowparkSQLException:
             return False
```
snowflake/ml/model/_client/sql/service.py

```diff
@@ -1,4 +1,6 @@
+import dataclasses
 import enum
+import logging
 import textwrap
 from typing import Any, Optional, Union
 
@@ -13,26 +15,59 @@ from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 
-# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
-# except UNKNOWN
 class ServiceStatus(enum.Enum):
-    …
-    PENDING = "PENDING"  # resource set is being created, can't be used yet
-    SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
-    SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
-    DELETING = "DELETING"  # resource set is being deleted
-    FAILED = "FAILED"  # resource set has failed and cannot be used anymore
-    DONE = "DONE"  # resource set has finished running
-    INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    PENDING = "PENDING"
     RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    DONE = "DONE"
+    SUSPENDING = "SUSPENDING"
+    SUSPENDED = "SUSPENDED"
+    DELETING = "DELETING"
     DELETED = "DELETED"
+    INTERNAL_ERROR = "INTERNAL_ERROR"
+
+
+class InstanceStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    FAILED = "FAILED"
+    TERMINATING = "TERMINATING"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class ContainerStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    DONE = "DONE"
+    FAILED = "FAILED"
+    UNKNOWN = "UNKNOWN"
+
+
+@dataclasses.dataclass
+class ServiceStatusInfo:
+    """
+    Class containing information about service container status.
+    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service
+    """
+
+    service_status: ServiceStatus
+    instance_id: Optional[int] = None
+    instance_status: Optional[InstanceStatus] = None
+    container_status: Optional[ContainerStatus] = None
+    message: Optional[str] = None
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
     SERVICE_STATUS = "service_status"
+    INSTANCE_ID = "instance_id"
+    INSTANCE_STATUS = "instance_status"
+    CONTAINER_STATUS = "status"
+    MESSAGE = "message"
 
     def build_model_container(
         self,
```
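Note that the container-status column in the `SHOW SERVICE CONTAINERS IN SERVICE` output is literally named `status`, hence the `CONTAINER_STATUS = "status"` constant. The raw row values are plain strings; a small sketch of the coercion the new client code performs (the sample row dict is invented):

```python
from enum import Enum
from typing import Optional

class InstanceStatus(Enum):
    PENDING = "PENDING"
    READY = "READY"
    FAILED = "FAILED"
    TERMINATING = "TERMINATING"
    SUCCEEDED = "SUCCEEDED"

# A raw SHOW SERVICE CONTAINERS row, represented as a dict (values invented).
row = {"service_status": "RUNNING", "instance_status": "READY", "status": "READY"}

# Enum(value) raises ValueError on an unrecognized string, so a new status
# introduced server-side fails loudly instead of being silently mis-mapped.
instance_status: Optional[InstanceStatus] = (
    InstanceStatus(row["instance_status"]) if row["instance_status"] is not None else None
)
print(instance_status)  # InstanceStatus.READY
```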
```diff
@@ -81,6 +116,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
     ) -> tuple[str, snowpark.AsyncJob]:
         assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
         if model_deployment_spec_yaml_str:
+            model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
+                model_deployment_spec_yaml_str
+            )  # type: ignore[no-untyped-call]
+            logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
         else:
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
```
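The deployment spec is YAML inlined into a single-quoted SQL literal, so any embedded single quotes must be escaped first; that is what the added `escape_single_quotes` call guards against. A stand-in sketch of the idea (standard SQL doubles embedded quotes; the exact escaping `snowpark_utils.escape_single_quotes` performs may differ):

```python
def escape_single_quotes(s: str) -> str:
    """Stand-in escaper: double embedded single quotes, SQL-style."""
    return s.replace("'", "''")

spec_yaml = "comment: 'quoted value'"
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{escape_single_quotes(spec_yaml)}')"
print(sql_str)  # the YAML survives as one well-formed SQL string literal
```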
```diff
@@ -192,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return str(rows[0][system_func])
 
-    def …
+    def get_service_container_statuses(
         self,
         *,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -200,18 +239,27 @@
         service_name: sql_identifier.SqlIdentifier,
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
-    ) -> …
+    ) -> list[ServiceStatusInfo]:
         fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
         query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
         rows = self._session.sql(query).collect(statement_params=statement_params)
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        statuses = []
+        for r in rows:
+            instance_status, container_status = None, None
+            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
+                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
+            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
+                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
+            statuses.append(
+                ServiceStatusInfo(
+                    service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
+                    instance_id=r[ServiceSQLClient.INSTANCE_ID],
+                    instance_status=instance_status,
+                    container_status=container_status,
+                    message=r[ServiceSQLClient.MESSAGE] if include_message else None,
+                )
+            )
+        return statuses
 
     def drop_service(
         self,
```
snowflake/ml/model/_client/sql/stage.py

```diff
@@ -12,9 +12,12 @@ class StageSQLClient(_base._BaseSQLClient):
         schema_name: Optional[sql_identifier.SqlIdentifier],
         stage_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[dict[str, Any]] = None,
-    ) -> …
+    ) -> str:
+        fq_stage_name = self.fully_qualified_object_name(database_name, schema_name, stage_name)
         query_result_checker.SqlResultValidator(
             self._session,
-            f"CREATE SCOPED TEMPORARY STAGE {…
+            f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+        return fq_stage_name
```
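The stage helper now computes the fully qualified stage name once, uses it in the DDL, and returns it so callers stop re-deriving the name. The shape of that change, sketched standalone (the function and helper names here are simplified stand-ins; the real method's name is not shown in this hunk):

```python
def fully_qualified_object_name(db: str, schema: str, name: str) -> str:
    # Simplified stand-in for _BaseSQLClient.fully_qualified_object_name.
    return f"{db}.{schema}.{name}"

def create_tmp_stage(db: str, schema: str, stage: str) -> str:
    fq_stage_name = fully_qualified_object_name(db, schema, stage)
    sql = f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}"
    # A live implementation would execute `sql` and validate the result here.
    return fq_stage_name  # callers reuse the name instead of rebuilding it

print(create_tmp_stage("MY_DB", "PUBLIC", "TMP_STAGE"))  # MY_DB.PUBLIC.TMP_STAGE
```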
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

```diff
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
 import cloudpickle
 import numpy as np
 import pandas as pd
+import shap
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework.base import BaseEstimator
 
 
+def _apply_transforms_up_to_last_step(
+    model: "BaseEstimator",
+    data: model_types.SupportedDataType,
+) -> pd.DataFrame:
+    """Apply all transformations in the snowml pipeline model up to the last step."""
+    if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            data = pd.DataFrame(step.transform(data))
+    return data
+
+
 @final
 class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     """Handler for SnowML based model.
```
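`_apply_transforms_up_to_last_step` feeds SHAP the features as the final estimator actually sees them, by running every pipeline step except the last. The same loop on a plain scikit-learn `Pipeline`, which exposes the same `steps` attribute (toy data invented):

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])
X = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
y = [0, 0, 1, 1]
pipe.fit(X, y)

# Transform through all but the last step, exactly like the new helper.
data = X
for step_name, step in pipe.steps[:-1]:
    if not hasattr(step, "transform"):
        raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
    data = pd.DataFrame(step.transform(data))

print(data.shape)  # the feature matrix as seen by the final "clf" step
```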
```diff
@@ -39,7 +53,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
-    EXPLAIN_TARGET_METHODS = ["…
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     IS_AUTO_SIGNATURE = True
 
@@ -97,11 +111,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 return result
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
-        …
-        if enable_explainability:
-            raise ValueError(
-                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-            )
         return None
 
     @classmethod
@@ -189,23 +198,46 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         else:
             enable_explainability = True
         if enable_explainability:
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-                target_method=explain_target_method,
-                output_return_type=model_task_and_output_type.output_type,
-            )
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
-            )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
+            try:
+                model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                    python_base_obj, model_meta.task
+                )
+                model_meta.task = model_task_and_output_type.task
+                background_data = handlers_utils.get_explainability_supported_background(
+                    sample_input_data, model_meta, explain_target_method
                 )
+                if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+                    transformed_df = _apply_transforms_up_to_last_step(model, sample_input_data)
+                    explain_fn = cls._build_explain_fn(model, background_data)
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,  # type: ignore[arg-type]
+                        background_data=background_data,
+                        explain_fn=explain_fn,
+                        output_feature_names=transformed_df.columns,
+                    )
+                else:
+                    model_meta = handlers_utils.add_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        output_return_type=model_task_and_output_type.output_type,
+                    )
+                if background_data is not None:
+                    handlers_utils.save_background_data(
+                        model_blobs_dir_path,
+                        cls.EXPLAIN_ARTIFACTS_DIR,
+                        cls.BG_DATA_FILE_SUFFIX,
+                        name,
+                        background_data,
+                    )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
```
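The behavioral change here: explainability setup is now wrapped in a broad `try`, and a failure only raises when the user explicitly passed `enable_explainability=True`; when it was auto-enabled, the model is saved without an explain method. A compressed sketch of that opt-in-only failure policy (the callable and kwargs are stand-ins):

```python
def save_with_optional_explain(setup_explain, kwargs: dict) -> None:
    """setup_explain is any callable that may fail for unsupported models."""
    try:
        setup_explain()
    except Exception:
        if kwargs.get("enable_explainability", None):
            # user explicitly opted in, so surface the failure
            raise ValueError(
                "Explainability for this model is not supported. "
                "Please set `enable_explainability=False`"
            )
        # auto-enabled: degrade gracefully and keep saving the model

def failing_setup() -> None:
    raise RuntimeError("no supported SHAP explainer")

save_with_optional_explain(failing_setup, {})  # silent skip
# save_with_optional_explain(failing_setup, {"enable_explainability": True})  # raises
```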
```diff
@@ -251,6 +283,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         assert isinstance(m, BaseEstimator)
         return m
 
+    @classmethod
+    def _build_explain_fn(
+        cls, model: "BaseEstimator", background_data: model_types.SupportedDataType
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+
+        predictor = model
+        is_pipeline = type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model)
+        if is_pipeline:
+            background_data = _apply_transforms_up_to_last_step(model, background_data)
+            predictor = model.steps[-1][1]  # type: ignore[attr-defined]
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            data = _apply_transforms_up_to_last_step(model, data)
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn", None]  # None just uses the predictor directly
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(predictor, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    return handlers_utils.convert_explanations_to_2D_df(model, explainer.shap_values(data))
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:  # type: ignore[assignment]
+                try:
+                    base_model = getattr(predictor, method_name)() if method_name is not None else predictor
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                    except TypeError:
+                        for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                            if not hasattr(base_model, explain_target_method):
+                                continue
+                            explain_target_method_fn = getattr(base_model, explain_target_method)
+                            if isinstance(data, np.ndarray):
+                                explainer = shap.Explainer(
+                                    explain_target_method_fn,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(explain_target_method_fn, background_data)
+                            return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                except Exception:
+                    pass  # Do nothing and continue to the next method
+            raise ValueError("Explainability for this model is not supported.")
+
+        return explain_fn
+
     @classmethod
     def convert_as_custom_model(
         cls,
```
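`_build_explain_fn` centralizes a fallback ladder: `shap.TreeExplainer` for models convertible to XGBoost/LightGBM, then the generic `shap.Explainer` on the estimator itself, then an explainer over one of its prediction methods. The same ladder on a plain scikit-learn model (assumes `shap` and scikit-learn are installed; the data is random toy data):

```python
import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier

X = np.random.RandomState(0).rand(50, 4)
y = (X[:, 0] > 0.5).astype(int)
model = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

try:
    explainer = shap.TreeExplainer(model)            # fastest path, tree models only
    values = explainer.shap_values(X[:5])
except Exception:
    try:
        explainer = shap.Explainer(model, masker=X)  # generic path
        values = explainer(X[:5]).values
    except TypeError:
        # last resort: explain a prediction function rather than the model object
        explainer = shap.Explainer(model.predict_proba, X)
        values = explainer(X[:5]).values

print(np.shape(values))
```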
```diff
@@ -286,57 +365,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            …
-            …
-            tree_methods = ["to_xgboost", "to_lightgbm"]
-            non_tree_methods = ["to_sklearn"]
-            for method_name in tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    explainer = shap.TreeExplainer(base_model)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            for method_name in non_tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    try:
-                        explainer = shap.Explainer(base_model, masker=background_data)
-                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                    except TypeError:
-                        try:
-                            dtype_map = {
-                                spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
-                            }
-
-                            if isinstance(X, pd.DataFrame):
-                                X = X.astype(dtype_map, copy=False)
-                            if hasattr(base_model, "predict_proba"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict_proba,
-                                        background_data.values,  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
-                            elif hasattr(base_model, "predict"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict, background_data.values  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict, background_data)
-                            else:
-                                raise ValueError("Missing any supported target method to explain.")
-                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                        except TypeError as e:
-                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
+            fn = cls._build_explain_fn(raw_model, background_data)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn
```
snowflake/ml/model/_signatures/core.py

```diff
@@ -559,6 +559,30 @@ class ModelSignature:
             )"""
         )
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model signature.
+
+        Returns:
+            str: HTML string containing formatted signature details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Create collapsible sections for inputs and outputs
+        inputs_content = html_utils.create_features_html(self.inputs, "Input")
+        outputs_content = html_utils.create_features_html(self.outputs, "Output")
+
+        inputs_section = html_utils.create_collapsible_section("Inputs", inputs_content, open_by_default=True)
+        outputs_section = html_utils.create_collapsible_section("Outputs", outputs_content, open_by_default=True)
+
+        content = f"""
+        <div style="margin-top: 10px;">
+            {inputs_section}
+            {outputs_section}
+        </div>
+        """
+
+        return html_utils.create_base_container("Model Signature", content)
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
```
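`_repr_html_` is the standard IPython/Jupyter rich-display hook: when a cell's last expression is an object exposing it, the notebook renders the returned HTML instead of `repr()`. A toy illustration of the hook itself (the `html_utils` helpers the real method uses live in the new `snowflake/ml/utils/html_utils.py`):

```python
class ToySignature:
    """Any object with _repr_html_ renders as HTML in a notebook cell."""

    def __init__(self, inputs: list[str], outputs: list[str]) -> None:
        self.inputs, self.outputs = inputs, outputs

    def _repr_html_(self) -> str:
        items = "".join(f"<li>{name}</li>" for name in self.inputs + self.outputs)
        return f"<details open><summary>Model Signature</summary><ul>{items}</ul></details>"

sig = ToySignature(["FEATURE_A", "FEATURE_B"], ["PREDICTION"])
print(sig._repr_html_())  # Jupyter calls this implicitly and renders the HTML
```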
snowflake/ml/monitoring/explain_visualize.py

```diff
@@ -272,8 +272,8 @@ def plot_influence_sensitivity(
     If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.
 
     Args:
-        …
-        …
+        shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
+        feature_values: pandas Series or 2D array containing the feature values for the same feature
         figsize: tuple of (width, height) for the plot
 
     Returns:
```
snowflake/ml/monitoring/model_monitor.py

```diff
@@ -1,7 +1,5 @@
-from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.monitoring import model_monitor_version
 from snowflake.ml.monitoring._client import model_monitor_sql_client
 
 
@@ -29,7 +27,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def suspend(self) -> None:
         """Suspend the Model Monitor"""
         statement_params = telemetry.get_statement_params(
@@ -42,7 +39,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def resume(self) -> None:
         """Resume the Model Monitor"""
         statement_params = telemetry.get_statement_params(
```
snowflake/ml/registry/registry.py

```diff
@@ -14,7 +14,7 @@ from snowflake.ml.model import (
     type_hints as model_types,
 )
 from snowflake.ml.model._client.model import model_version_impl
-from snowflake.ml.monitoring import model_monitor
+from snowflake.ml.monitoring import model_monitor
 from snowflake.ml.monitoring._manager import model_monitor_manager
 from snowflake.ml.monitoring.entities import model_monitor_config
 from snowflake.ml.registry._manager import model_manager
@@ -30,6 +30,7 @@ _MODEL_MONITORING_DISABLED_ERROR = (
 
 
 class Registry:
+    @telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT)
     def __init__(
         self,
         session: session.Session,
@@ -74,6 +75,22 @@ class Registry:
             else sql_identifier.SqlIdentifier("PUBLIC")
         )
 
+        database_exists = session.sql(
+            f"""SELECT 1 FROM INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME = '{self._database_name.resolved()}';"""
+        ).collect()
+
+        if not database_exists:
+            raise ValueError(f"Database {self._database_name} does not exist.")
+
+        schema_exists = session.sql(
+            f"""
+            SELECT 1 FROM {self._database_name.identifier()}.INFORMATION_SCHEMA.SCHEMATA
+            WHERE SCHEMA_NAME = '{self._schema_name.resolved()}';"""
+        ).collect()
+
+        if not schema_exists:
+            raise ValueError(f"Schema {self._schema_name} does not exist.")
+
         self._model_manager = model_manager.ModelManager(
             session,
             database_name=self._database_name,
```
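`Registry.__init__` now fails fast with a clear `ValueError` when the target database or schema is missing, instead of surfacing an opaque error on first use; note the queries compare against the resolved (case-normalized) identifier. The check, extracted as a standalone function (`session` is assumed to be a live `snowflake.snowpark.Session`):

```python
from snowflake.snowpark import Session

def assert_schema_exists(session: Session, database_name: str, schema_name: str) -> None:
    """Fail fast if DATABASE.SCHEMA does not exist; names must be resolved/unquoted."""
    if not session.sql(
        f"SELECT 1 FROM INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME = '{database_name}'"
    ).collect():
        raise ValueError(f"Database {database_name} does not exist.")
    if not session.sql(
        f"SELECT 1 FROM {database_name}.INFORMATION_SCHEMA.SCHEMATA "
        f"WHERE SCHEMA_NAME = '{schema_name}'"
    ).collect():
        raise ValueError(f"Schema {schema_name} does not exist.")
```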
```diff
@@ -155,7 +172,11 @@
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-                …
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
+                - ["WAREHOUSE"] (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
+                Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
                 sample_input_data would be used to infer the signatures for those models that cannot automatically
@@ -295,8 +316,11 @@
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-                …
-                …
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
+                - ["WAREHOUSE"] (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
+                Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
                 sample_input_data would be used to infer the signatures for those models that cannot automatically
```
```diff
@@ -397,11 +421,11 @@
         if task is not model_types.Task.UNKNOWN:
             raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
 
-        if pip_requirements:
+        if pip_requirements and not artifact_repository_map:
             warnings.warn(
-                "Models logged specifying `pip_requirements` …
-                "…
-                "…
+                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
+                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
+                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
                 category=UserWarning,
                 stacklevel=1,
             )
```
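The warning condition is narrowed: per the new message, pip-only dependencies cannot run in a warehouse unless an `artifact_repository_map` is also provided, so the warning now fires only in that case. The guard in isolation (sample values invented):

```python
import warnings

pip_requirements = ["xgboost==2.0.3"]  # sample value, invented
artifact_repository_map = None         # not provided

if pip_requirements and not artifact_repository_map:
    warnings.warn(
        "Models logged specifying `pip_requirements` cannot be executed in a "
        "Snowflake Warehouse without specifying `artifact_repository_map`.",
        category=UserWarning,
        stacklevel=1,
    )
```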
```diff
@@ -500,7 +524,6 @@
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def add_monitor(
         self,
         name: str,
@@ -525,7 +548,7 @@
         return self._model_monitor_manager.add_monitor(name, source_config, model_monitor_config)
 
     @overload
-    def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
+    def get_monitor(self, *, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
         """Get a Model Monitor on a Model Version from the Registry.
 
         Args:
@@ -534,7 +557,7 @@
         ...
 
     @overload
-    def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
+    def get_monitor(self, *, name: str) -> model_monitor.ModelMonitor:
         """Get a Model Monitor by name from the Registry.
 
         Args:
```
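Adding `*` makes both `get_monitor` overloads keyword-only, so an ambiguous positional call can no longer match either overload; callers must name the argument. A minimal reproduction of the pattern (the `Monitor` and `ModelVersion` classes are stand-in types):

```python
from typing import Optional, overload

class ModelVersion: ...
class Monitor: ...

@overload
def get_monitor(*, model_version: ModelVersion) -> Monitor: ...
@overload
def get_monitor(*, name: str) -> Monitor: ...

def get_monitor(*, name: Optional[str] = None,
                model_version: Optional[ModelVersion] = None) -> Monitor:
    if (name is None) == (model_version is None):
        raise ValueError("Pass exactly one of `name` or `model_version`.")
    return Monitor()

get_monitor(name="MY_MONITOR")   # OK
# get_monitor("MY_MONITOR")      # TypeError: positional arguments are rejected
```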
```diff
@@ -546,7 +569,6 @@
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def get_monitor(
         self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
     ) -> model_monitor.ModelMonitor:
@@ -575,7 +597,6 @@
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def show_model_monitors(self) -> list[snowpark.Row]:
         """Show all model monitors in the registry.
 
@@ -593,7 +614,6 @@
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def delete_monitor(self, name: str) -> None:
         """Delete a Model Monitor by name from the Registry.
 
```
snowflake/ml/utils/connection_params.py

```diff
@@ -136,7 +136,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
     return conn_params
 
 
-@snowpark._internal.utils.…
+@snowpark._internal.utils.deprecated(version="1.8.5")
 def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
     """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
 
```
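`SnowflakeLoginOptions` is now tagged with a deprecation decorator from Snowpark's internal utils. An illustrative generic equivalent (this is not the Snowpark implementation, which is internal and may behave differently):

```python
import functools
import warnings

def deprecated(version: str):
    """Emit a DeprecationWarning each time the wrapped function is called."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated since version {version}.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)
        return inner
    return wrap

@deprecated(version="1.8.5")
def SnowflakeLoginOptions(connection_name: str = ""):
    return {}

SnowflakeLoginOptions()  # emits a DeprecationWarning, then runs normally
```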