snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +42 -16
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +12 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +95 -39
- snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
- snowflake/ml/jobs/_utils/spec_utils.py +30 -6
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +10 -7
- snowflake/ml/jobs/job.py +176 -28
- snowflake/ml/jobs/manager.py +119 -26
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +24 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +73 -28
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +9 -3
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model version.
+
+        Returns:
+            str: HTML string containing formatted model version details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Get task
+        try:
+            task = self.get_model_task().value
+        except Exception:
+            task = (
+                html_utils.create_error_message("Not available")
+                .replace('<em style="color: #888; font-style: italic;">', "")
+                .replace("</em>", "")
+            )
+
+        # Get functions info for display
+        try:
+            functions = self.show_functions()
+            if not functions:
+                functions_html = html_utils.create_error_message("No functions available")
+            else:
+                functions_list = []
+                for func in functions:
+                    try:
+                        sig_html = func["signature"]._repr_html_()
+                    except Exception:
+                        # Fallback to simple display if can't display signature
+                        sig_html = f"<pre style='margin: 5px 0;'>{func['signature']}</pre>"
+
+                    function_content = f"""
+                    <div style="margin: 5px 0;">
+                        <strong>Target Method:</strong> {func['target_method']}
+                    </div>
+                    <div style="margin: 5px 0;">
+                        <strong>Function Type:</strong> {func.get('target_method_function_type', 'N/A')}
+                    </div>
+                    <div style="margin: 5px 0;">
+                        <strong>Partitioned:</strong> {func.get('is_partitioned', False)}
+                    </div>
+                    <div style="margin: 10px 0;">
+                        <strong>Signature:</strong>
+                        {sig_html}
+                    </div>
+                    """
+
+                    functions_list.append(
+                        html_utils.create_collapsible_section(
+                            title=func["name"], content=function_content, open_by_default=False
+                        )
+                    )
+                functions_html = "".join(functions_list)
+        except Exception:
+            functions_html = html_utils.create_error_message("Error retrieving functions")
+
+        # Get metrics for display
+        try:
+            metrics = self.show_metrics()
+            if not metrics:
+                metrics_html = html_utils.create_error_message("No metrics available")
+            else:
+                metrics_html = ""
+                for metric_name, value in metrics.items():
+                    metrics_html += html_utils.create_metric_item(metric_name, value)
+        except Exception:
+            metrics_html = html_utils.create_error_message("Error retrieving metrics")
+
+        # Create main content sections
+        main_info = html_utils.create_grid_section(
+            [
+                ("Model Name", self.model_name),
+                ("Version", f'<strong style="color: #28a745;">{self.version_name}</strong>'),
+                ("Full Name", self.fully_qualified_model_name),
+                ("Description", self.description),
+                ("Task", task),
+            ]
+        )
+
+        functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
+            functions_html
+        )
+
+        metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)
+
+        content = main_info + functions_section + metrics_section
+
+        return html_utils.create_base_container("Model Version Details", content)
+
     @classmethod
     def _ref(
         cls,
@@ -643,14 +643,17 @@ class ModelOperator:
         # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
-        result = []
-
+        result: list[ServiceInfo] = []
         for fully_qualified_service_name in fully_qualified_service_names:
             ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
-
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             )
+            if len(statuses) == 0:
+                return result
+
+            service_status = statuses[0].service_status
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -125,6 +125,7 @@ class ServiceOperator:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
             stage_path = None
+        self._model_deployment_spec.clear()
         self._model_deployment_spec.add_model_spec(
             database_name=database_name,
             schema_name=schema_name,
@@ -168,7 +169,7 @@ class ServiceOperator:
             schema_name=service_schema_name,
             service_name=service_name,
             service_status_list_if_exists=[
-                service_sql.ServiceStatus.
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
             ],
@@ -324,14 +325,15 @@ class ServiceOperator:
                 )
                 continue
 
-
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=service_log_meta.service.database_name,
                 schema_name=service_log_meta.service.schema_name,
                 service_name=service_log_meta.service.service_name,
                 include_message=True,
                 statement_params=statement_params,
             )
-
+            service_status = statuses[0].service_status
+            if (service_status != service_sql.ServiceStatus.RUNNING) or (
                 service_status != service_log_meta.service_status
             ):
                 service_log_meta.service_status = service_status
@@ -340,7 +342,19 @@ class ServiceOperator:
                     f"{service_log_meta.service.display_service_name} is "
                     f"{service_log_meta.service_status.value}."
                 )
-
+                for status in statuses:
+                    if status.instance_id is not None:
+                        instance_status, container_status = None, None
+                        if status.instance_status is not None:
+                            instance_status = status.instance_status.value
+                        if status.container_status is not None:
+                            container_status = status.container_status.value
+                        module_logger.info(
+                            f"Instance[{status.instance_id}]: "
+                            f"instance status: {instance_status}, "
+                            f"container status: {container_status}, "
+                            f"message: {status.message}"
+                        )
 
             new_logs, new_offset = fetch_logs(
                 service_log_meta.service,
@@ -352,13 +366,14 @@ class ServiceOperator:
 
             # check if model build service is done
             if not service_log_meta.is_model_build_service_done:
-
+                statuses = self._service_client.get_service_container_statuses(
                     database_name=model_build_service.database_name,
                     schema_name=model_build_service.schema_name,
                     service_name=model_build_service.service_name,
                     include_message=False,
                     statement_params=statement_params,
                 )
+                service_status = statuses[0].service_status
 
                 if service_status == service_sql.ServiceStatus.DONE:
                     set_service_log_metadata_to_model_inference(
@@ -428,20 +443,21 @@ class ServiceOperator:
         if service_status_list_if_exists is None:
             service_status_list_if_exists = [
                 service_sql.ServiceStatus.PENDING,
-                service_sql.ServiceStatus.
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
                 service_sql.ServiceStatus.DONE,
                 service_sql.ServiceStatus.FAILED,
             ]
         try:
-
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=database_name,
                 schema_name=schema_name,
                 service_name=service_name,
                 include_message=False,
                 statement_params=statement_params,
             )
+            service_status = statuses[0].service_status
             return any(service_status == status for status in service_status_list_if_exists)
         except exceptions.SnowparkSQLException:
             return False
@@ -538,6 +554,7 @@ class ServiceOperator:
         )
 
         try:
+            self._model_deployment_spec.clear()
             # save the spec
             self._model_deployment_spec.add_model_spec(
                 database_name=database_name,
@@ -29,6 +29,17 @@ class ModelDeploymentSpec:
         self.database: Optional[sql_identifier.SqlIdentifier] = None
         self.schema: Optional[sql_identifier.SqlIdentifier] = None
 
+    def clear(self) -> None:
+        """Reset the deployment spec to its initial state."""
+        self._models = []
+        self._image_build = None
+        self._service = None
+        self._job = None
+        self._model_loggings = None
+        self._inference_spec = {}
+        self.database = None
+        self.schema = None
+
     def add_model_spec(
         self,
         database_name: sql_identifier.SqlIdentifier,
@@ -293,7 +293,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             options = {"parallel": 10}
             cursor = self._session._conn._cursor
-            cursor._download(stage_location_url, str(target_path), options)
+            cursor._download(stage_location_url, str(target_path), options)
             cursor.fetchall()
         else:
             query_result_checker.SqlResultValidator(
@@ -1,5 +1,6 @@
+import dataclasses
 import enum
-import json
+import logging
 import textwrap
 from typing import Any, Optional, Union
 
@@ -14,23 +15,59 @@ from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 
 class ServiceStatus(enum.Enum):
-
-
-
-
-
-
-
-
-
-
+    PENDING = "PENDING"
+    RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    DONE = "DONE"
+    SUSPENDING = "SUSPENDING"
+    SUSPENDED = "SUSPENDED"
+    DELETING = "DELETING"
+    DELETED = "DELETED"
+    INTERNAL_ERROR = "INTERNAL_ERROR"
+
+
+class InstanceStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    FAILED = "FAILED"
+    TERMINATING = "TERMINATING"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class ContainerStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    DONE = "DONE"
+    FAILED = "FAILED"
+    UNKNOWN = "UNKNOWN"
+
+
+@dataclasses.dataclass
+class ServiceStatusInfo:
+    """
+    Class containing information about service container status.
+    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service
+    """
+
+    service_status: ServiceStatus
+    instance_id: Optional[int] = None
+    instance_status: Optional[InstanceStatus] = None
+    container_status: Optional[ContainerStatus] = None
+    message: Optional[str] = None
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"
+    INSTANCE_ID = "instance_id"
+    INSTANCE_STATUS = "instance_status"
+    CONTAINER_STATUS = "status"
+    MESSAGE = "message"
 
     def build_model_container(
         self,
@@ -79,6 +116,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
     ) -> tuple[str, snowpark.AsyncJob]:
         assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
         if model_deployment_spec_yaml_str:
+            model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
+                model_deployment_spec_yaml_str
+            )  # type: ignore[no-untyped-call]
+            logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
         else:
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
@@ -190,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return str(rows[0][system_func])
 
-    def
+    def get_service_container_statuses(
         self,
         *,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -198,23 +239,27 @@ class ServiceSQLClient(_base._BaseSQLClient):
         service_name: sql_identifier.SqlIdentifier,
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
-    ) ->
-
-
-
-
-
-
+    ) -> list[ServiceStatusInfo]:
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        statuses = []
+        for r in rows:
+            instance_status, container_status = None, None
+            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
+                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
+            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
+                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
+            statuses.append(
+                ServiceStatusInfo(
+                    service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
+                    instance_id=r[ServiceSQLClient.INSTANCE_ID],
+                    instance_status=instance_status,
+                    container_status=container_status,
+                    message=r[ServiceSQLClient.MESSAGE] if include_message else None,
+                )
             )
-
-            .validate()
-        )
-        metadata = json.loads(rows[0][system_func])[0]
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+        return statuses
 
     def drop_service(
         self,
@@ -12,9 +12,12 @@ class StageSQLClient(_base._BaseSQLClient):
         schema_name: Optional[sql_identifier.SqlIdentifier],
         stage_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[dict[str, Any]] = None,
-    ) ->
+    ) -> str:
+        fq_stage_name = self.fully_qualified_object_name(database_name, schema_name, stage_name)
         query_result_checker.SqlResultValidator(
             self._session,
-            f"CREATE SCOPED TEMPORARY STAGE {
+            f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+        return fq_stage_name
@@ -188,7 +188,9 @@ class ModelComposer:
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure()
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
@@ -216,7 +216,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     explain_fn=cls._build_explain_fn(model, background_data, input_signature),
                     output_feature_names=transformed_background_data.columns,
                 )
-            except
+            except Exception:
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
 import cloudpickle
 import numpy as np
 import pandas as pd
+import shap
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework.base import BaseEstimator
 
 
+def _apply_transforms_up_to_last_step(
+    model: "BaseEstimator",
+    data: model_types.SupportedDataType,
+) -> pd.DataFrame:
+    """Apply all transformations in the snowml pipeline model up to the last step."""
+    if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            data = pd.DataFrame(step.transform(data))
+    return data
+
+
 @final
 class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     """Handler for SnowML based model.
@@ -39,7 +53,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
-    EXPLAIN_TARGET_METHODS = ["
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     IS_AUTO_SIGNATURE = True
@@ -97,11 +111,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 return result
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
-
-        if enable_explainability:
-            raise ValueError(
-                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-            )
         return None
 
     @classmethod
@@ -189,23 +198,46 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         else:
             enable_explainability = True
         if enable_explainability:
-
-
-
-
-
-
-
-                target_method=explain_target_method,
-                output_return_type=model_task_and_output_type.output_type,
-            )
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
-            )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
+            try:
+                model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                    python_base_obj, model_meta.task
+                )
+                model_meta.task = model_task_and_output_type.task
+                background_data = handlers_utils.get_explainability_supported_background(
+                    sample_input_data, model_meta, explain_target_method
                 )
+                if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+                    transformed_df = _apply_transforms_up_to_last_step(model, sample_input_data)
+                    explain_fn = cls._build_explain_fn(model, background_data)
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,  # type: ignore[arg-type]
+                        background_data=background_data,
+                        explain_fn=explain_fn,
+                        output_feature_names=transformed_df.columns,
+                    )
+                else:
+                    model_meta = handlers_utils.add_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        output_return_type=model_task_and_output_type.output_type,
+                    )
+                if background_data is not None:
+                    handlers_utils.save_background_data(
+                        model_blobs_dir_path,
+                        cls.EXPLAIN_ARTIFACTS_DIR,
+                        cls.BG_DATA_FILE_SUFFIX,
+                        name,
+                        background_data,
+                    )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -251,6 +283,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         assert isinstance(m, BaseEstimator)
         return m
 
+    @classmethod
+    def _build_explain_fn(
+        cls, model: "BaseEstimator", background_data: model_types.SupportedDataType
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+
+        predictor = model
+        is_pipeline = type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model)
+        if is_pipeline:
+            background_data = _apply_transforms_up_to_last_step(model, background_data)
+            predictor = model.steps[-1][1]  # type: ignore[attr-defined]
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            data = _apply_transforms_up_to_last_step(model, data)
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn", None]  # None just uses the predictor directly
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(predictor, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    return handlers_utils.convert_explanations_to_2D_df(model, explainer.shap_values(data))
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:  # type: ignore[assignment]
+                try:
+                    base_model = getattr(predictor, method_name)() if method_name is not None else predictor
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                    except TypeError:
+                        for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                            if not hasattr(base_model, explain_target_method):
+                                continue
+                            explain_target_method_fn = getattr(base_model, explain_target_method)
+                            if isinstance(data, np.ndarray):
+                                explainer = shap.Explainer(
+                                    explain_target_method_fn,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(explain_target_method_fn, background_data)
+                            return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                except Exception:
+                    pass  # Do nothing and continue to the next method
+            raise ValueError("Explainability for this model is not supported.")
+
+        return explain_fn
+
     @classmethod
     def convert_as_custom_model(
         cls,
@@ -286,57 +365,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-
-
-            tree_methods = ["to_xgboost", "to_lightgbm"]
-            non_tree_methods = ["to_sklearn"]
-            for method_name in tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    explainer = shap.TreeExplainer(base_model)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            for method_name in non_tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    try:
-                        explainer = shap.Explainer(base_model, masker=background_data)
-                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                    except TypeError:
-                        try:
-                            dtype_map = {
-                                spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
-                            }
-
-                            if isinstance(X, pd.DataFrame):
-                                X = X.astype(dtype_map, copy=False)
-                            if hasattr(base_model, "predict_proba"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict_proba,
-                                        background_data.values,  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
-                            elif hasattr(base_model, "predict"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict, background_data.values  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict, background_data)
-                            else:
-                                raise ValueError("Missing any supported target method to explain.")
-                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                        except TypeError as e:
-                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
+            fn = cls._build_explain_fn(raw_model, background_data)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn