snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. snowflake/ml/_internal/telemetry.py +42 -16
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/data/data_connector.py +1 -1
  4. snowflake/ml/jobs/__init__.py +2 -0
  5. snowflake/ml/jobs/_utils/constants.py +12 -2
  6. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +95 -39
  9. snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
  11. snowflake/ml/jobs/_utils/spec_utils.py +30 -6
  12. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -1
  14. snowflake/ml/jobs/decorators.py +10 -7
  15. snowflake/ml/jobs/job.py +176 -28
  16. snowflake/ml/jobs/manager.py +119 -26
  17. snowflake/ml/model/_client/model/model_impl.py +58 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  20. snowflake/ml/model/_client/ops/service_ops.py +24 -7
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  22. snowflake/ml/model/_client/sql/model_version.py +1 -1
  23. snowflake/ml/model/_client/sql/service.py +73 -28
  24. snowflake/ml/model/_client/sql/stage.py +5 -2
  25. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  29. snowflake/ml/model/_signatures/core.py +24 -0
  30. snowflake/ml/monitoring/explain_visualize.py +160 -22
  31. snowflake/ml/monitoring/model_monitor.py +0 -4
  32. snowflake/ml/registry/registry.py +34 -14
  33. snowflake/ml/utils/connection_params.py +9 -3
  34. snowflake/ml/utils/html_utils.py +263 -0
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
  37. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
  38. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
  39. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  40. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model version.
+
+        Returns:
+            str: HTML string containing formatted model version details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Get task
+        try:
+            task = self.get_model_task().value
+        except Exception:
+            task = (
+                html_utils.create_error_message("Not available")
+                .replace('<em style="color: #888; font-style: italic;">', "")
+                .replace("</em>", "")
+            )
+
+        # Get functions info for display
+        try:
+            functions = self.show_functions()
+            if not functions:
+                functions_html = html_utils.create_error_message("No functions available")
+            else:
+                functions_list = []
+                for func in functions:
+                    try:
+                        sig_html = func["signature"]._repr_html_()
+                    except Exception:
+                        # Fallback to simple display if can't display signature
+                        sig_html = f"<pre style='margin: 5px 0;'>{func['signature']}</pre>"
+
+                    function_content = f"""
+                    <div style="margin: 5px 0;">
+                        <strong>Target Method:</strong> {func['target_method']}
+                    </div>
+                    <div style="margin: 5px 0;">
+                        <strong>Function Type:</strong> {func.get('target_method_function_type', 'N/A')}
+                    </div>
+                    <div style="margin: 5px 0;">
+                        <strong>Partitioned:</strong> {func.get('is_partitioned', False)}
+                    </div>
+                    <div style="margin: 10px 0;">
+                        <strong>Signature:</strong>
+                        {sig_html}
+                    </div>
+                    """
+
+                    functions_list.append(
+                        html_utils.create_collapsible_section(
+                            title=func["name"], content=function_content, open_by_default=False
+                        )
+                    )
+                functions_html = "".join(functions_list)
+        except Exception:
+            functions_html = html_utils.create_error_message("Error retrieving functions")
+
+        # Get metrics for display
+        try:
+            metrics = self.show_metrics()
+            if not metrics:
+                metrics_html = html_utils.create_error_message("No metrics available")
+            else:
+                metrics_html = ""
+                for metric_name, value in metrics.items():
+                    metrics_html += html_utils.create_metric_item(metric_name, value)
+        except Exception:
+            metrics_html = html_utils.create_error_message("Error retrieving metrics")
+
+        # Create main content sections
+        main_info = html_utils.create_grid_section(
+            [
+                ("Model Name", self.model_name),
+                ("Version", f'<strong style="color: #28a745;">{self.version_name}</strong>'),
+                ("Full Name", self.fully_qualified_model_name),
+                ("Description", self.description),
+                ("Task", task),
+            ]
+        )
+
+        functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
+            functions_html
+        )
+
+        metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)
+
+        content = main_info + functions_section + metrics_section
+
+        return html_utils.create_base_container("Model Version Details", content)
+
     @classmethod
     def _ref(
         cls,
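
Note: `_repr_html_` is the standard rich-display hook that IPython/Jupyter calls when an object is the last expression in a cell, so the method added above makes a `ModelVersion` render as a formatted HTML card in notebooks without any caller changes. A minimal sketch of the same pattern (the `Card` class and markup are illustrative, not part of snowflake-ml-python):

    class Card:
        """Toy object that renders as rich HTML in Jupyter via the _repr_html_ hook."""

        def __init__(self, title: str, body: str) -> None:
            self.title = title
            self.body = body

        def _repr_html_(self) -> str:
            # Jupyter calls this automatically; plain consoles fall back to __repr__.
            return f"<div><strong>{self.title}</strong><p>{self.body}</p></div>"

    # In a notebook cell, evaluating Card("Model Version Details", "...") shows the HTML.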
@@ -643,14 +643,17 @@ class ModelOperator:
         # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
-        result = []
-
+        result: list[ServiceInfo] = []
         for fully_qualified_service_name in fully_qualified_service_names:
             ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
-            service_status, _ = self._service_client.get_service_status(
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             )
+            if len(statuses) == 0:
+                return result
+
+            service_status = statuses[0].service_status
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -125,6 +125,7 @@ class ServiceOperator:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
             stage_path = None
+        self._model_deployment_spec.clear()
         self._model_deployment_spec.add_model_spec(
             database_name=database_name,
             schema_name=schema_name,
@@ -168,7 +169,7 @@ class ServiceOperator:
                 schema_name=service_schema_name,
                 service_name=service_name,
                 service_status_list_if_exists=[
-                    service_sql.ServiceStatus.READY,
+                    service_sql.ServiceStatus.RUNNING,
                     service_sql.ServiceStatus.SUSPENDING,
                     service_sql.ServiceStatus.SUSPENDED,
                 ],
@@ -324,14 +325,15 @@ class ServiceOperator:
                 )
                 continue
 
-            service_status, message = self._service_client.get_service_status(
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=service_log_meta.service.database_name,
                 schema_name=service_log_meta.service.schema_name,
                 service_name=service_log_meta.service.service_name,
                 include_message=True,
                 statement_params=statement_params,
             )
-            if (service_status != service_sql.ServiceStatus.READY) or (
+            service_status = statuses[0].service_status
+            if (service_status != service_sql.ServiceStatus.RUNNING) or (
                 service_status != service_log_meta.service_status
             ):
                 service_log_meta.service_status = service_status
@@ -340,7 +342,19 @@ class ServiceOperator:
                     f"{service_log_meta.service.display_service_name} is "
                     f"{service_log_meta.service_status.value}."
                 )
-                module_logger.info(f"Service message: {message}")
+                for status in statuses:
+                    if status.instance_id is not None:
+                        instance_status, container_status = None, None
+                        if status.instance_status is not None:
+                            instance_status = status.instance_status.value
+                        if status.container_status is not None:
+                            container_status = status.container_status.value
+                        module_logger.info(
+                            f"Instance[{status.instance_id}]: "
+                            f"instance status: {instance_status}, "
+                            f"container status: {container_status}, "
+                            f"message: {status.message}"
+                        )
 
             new_logs, new_offset = fetch_logs(
                 service_log_meta.service,
@@ -352,13 +366,14 @@ class ServiceOperator:
 
             # check if model build service is done
             if not service_log_meta.is_model_build_service_done:
-                service_status, _ = self._service_client.get_service_status(
+                statuses = self._service_client.get_service_container_statuses(
                     database_name=model_build_service.database_name,
                     schema_name=model_build_service.schema_name,
                     service_name=model_build_service.service_name,
                     include_message=False,
                     statement_params=statement_params,
                 )
+                service_status = statuses[0].service_status
 
                 if service_status == service_sql.ServiceStatus.DONE:
                     set_service_log_metadata_to_model_inference(
@@ -428,20 +443,21 @@ class ServiceOperator:
         if service_status_list_if_exists is None:
             service_status_list_if_exists = [
                 service_sql.ServiceStatus.PENDING,
-                service_sql.ServiceStatus.READY,
+                service_sql.ServiceStatus.RUNNING,
                 service_sql.ServiceStatus.SUSPENDING,
                 service_sql.ServiceStatus.SUSPENDED,
                 service_sql.ServiceStatus.DONE,
                 service_sql.ServiceStatus.FAILED,
             ]
         try:
-            service_status, _ = self._service_client.get_service_status(
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=database_name,
                 schema_name=schema_name,
                 service_name=service_name,
                 include_message=False,
                 statement_params=statement_params,
             )
+            service_status = statuses[0].service_status
             return any(service_status == status for status in service_status_list_if_exists)
         except exceptions.SnowparkSQLException:
             return False
@@ -538,6 +554,7 @@ class ServiceOperator:
         )
 
         try:
+            self._model_deployment_spec.clear()
             # save the spec
             self._model_deployment_spec.add_model_spec(
                 database_name=database_name,
@@ -29,6 +29,17 @@ class ModelDeploymentSpec:
         self.database: Optional[sql_identifier.SqlIdentifier] = None
         self.schema: Optional[sql_identifier.SqlIdentifier] = None
 
+    def clear(self) -> None:
+        """Reset the deployment spec to its initial state."""
+        self._models = []
+        self._image_build = None
+        self._service = None
+        self._job = None
+        self._model_loggings = None
+        self._inference_spec = {}
+        self.database = None
+        self.schema = None
+
     def add_model_spec(
         self,
         database_name: sql_identifier.SqlIdentifier,
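
The `clear()` method added above pairs with the new `self._model_deployment_spec.clear()` calls in service_ops.py: the spec object is a reusable accumulator, so state from one deployment must be reset before the next. A minimal sketch of the reuse hazard being fixed (hypothetical `SpecBuilder`, not the library's class):

    class SpecBuilder:
        """Toy accumulating builder; a reused instance must be reset between runs."""

        def __init__(self) -> None:
            self.models: list[str] = []

        def clear(self) -> None:
            self.models = []

        def add_model(self, name: str) -> None:
            self.models.append(name)

    builder = SpecBuilder()
    builder.add_model("m1")
    builder.clear()          # without this, the next deployment would still carry "m1"
    builder.add_model("m2")
    assert builder.models == ["m2"]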
@@ -293,7 +293,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             options = {"parallel": 10}
             cursor = self._session._conn._cursor
-            cursor._download(stage_location_url, str(target_path), options)  # type: ignore[union-attr]
+            cursor._download(stage_location_url, str(target_path), options)
             cursor.fetchall()
         else:
             query_result_checker.SqlResultValidator(
@@ -1,5 +1,6 @@
+import dataclasses
 import enum
-import json
+import logging
 import textwrap
 from typing import Any, Optional, Union
 
@@ -14,23 +15,59 @@ from snowflake.ml.model._model_composer.model_method import constants
 from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 
 class ServiceStatus(enum.Enum):
-    UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
-    PENDING = "PENDING"  # resource set is being created, can't be used yet
-    READY = "READY"  # resource set has been deployed.
-    SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
-    SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
-    DELETING = "DELETING"  # resource set is being deleted
-    FAILED = "FAILED"  # resource set has failed and cannot be used anymore
-    DONE = "DONE"  # resource set has finished running
-    NOT_FOUND = "NOT_FOUND"  # not found or deleted
-    INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    PENDING = "PENDING"
+    RUNNING = "RUNNING"
+    FAILED = "FAILED"
+    DONE = "DONE"
+    SUSPENDING = "SUSPENDING"
+    SUSPENDED = "SUSPENDED"
+    DELETING = "DELETING"
+    DELETED = "DELETED"
+    INTERNAL_ERROR = "INTERNAL_ERROR"
+
+
+class InstanceStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    FAILED = "FAILED"
+    TERMINATING = "TERMINATING"
+    SUCCEEDED = "SUCCEEDED"
+
+
+class ContainerStatus(enum.Enum):
+    PENDING = "PENDING"
+    READY = "READY"
+    DONE = "DONE"
+    FAILED = "FAILED"
+    UNKNOWN = "UNKNOWN"
+
+
+@dataclasses.dataclass
+class ServiceStatusInfo:
+    """
+    Class containing information about service container status.
+    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service
+    """
+
+    service_status: ServiceStatus
+    instance_id: Optional[int] = None
+    instance_status: Optional[InstanceStatus] = None
+    container_status: Optional[ContainerStatus] = None
+    message: Optional[str] = None
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"
+    INSTANCE_ID = "instance_id"
+    INSTANCE_STATUS = "instance_status"
+    CONTAINER_STATUS = "status"
+    MESSAGE = "message"
 
     def build_model_container(
         self,
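
Each row of SHOW SERVICE CONTAINERS IN SERVICE describes one container of one service instance, so a single service maps to a list of `ServiceStatusInfo` records with the service-level status repeated on every row. A hedged sketch of how one row translates, using the enums and column constants defined above (field values are illustrative):

    # A SHOW SERVICE CONTAINERS IN SERVICE row, reduced to the columns read here.
    row = {
        "service_status": "RUNNING",
        "instance_id": 0,
        "instance_status": "READY",
        "status": "READY",       # the container-status column is named plain "status"
        "message": "Running",
    }

    info = ServiceStatusInfo(
        service_status=ServiceStatus(row["service_status"]),
        instance_id=row["instance_id"],
        instance_status=InstanceStatus(row["instance_status"]),
        container_status=ContainerStatus(row["status"]),
        message=row["message"],
    )
    assert info.service_status is ServiceStatus.RUNNING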
@@ -79,6 +116,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
     ) -> tuple[str, snowpark.AsyncJob]:
         assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
         if model_deployment_spec_yaml_str:
+            model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
+                model_deployment_spec_yaml_str
+            )  # type: ignore[no-untyped-call]
+            logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
         else:
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
@@ -190,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return str(rows[0][system_func])
 
-    def get_service_status(
+    def get_service_container_statuses(
         self,
         *,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -198,23 +239,27 @@ class ServiceSQLClient(_base._BaseSQLClient):
         schema_name: Optional[sql_identifier.SqlIdentifier],
         service_name: sql_identifier.SqlIdentifier,
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
-    ) -> tuple[ServiceStatus, Optional[str]]:
-        system_func = "SYSTEM$GET_SERVICE_STATUS"
-        rows = (
-            query_result_checker.SqlResultValidator(
-                self._session,
-                f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
-                statement_params=statement_params,
-            )
-            .has_dimensions(expected_rows=1, expected_cols=1)
-            .validate()
-        )
-        metadata = json.loads(rows[0][system_func])[0]
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+    ) -> list[ServiceStatusInfo]:
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        statuses = []
+        for r in rows:
+            instance_status, container_status = None, None
+            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
+                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
+            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
+                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
+            statuses.append(
+                ServiceStatusInfo(
+                    service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
+                    instance_id=r[ServiceSQLClient.INSTANCE_ID],
+                    instance_status=instance_status,
+                    container_status=container_status,
+                    message=r[ServiceSQLClient.MESSAGE] if include_message else None,
+                )
+            )
+        return statuses
 
     def drop_service(
         self,
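
Callers that previously received a single `(status, message)` tuple now get one entry per container, with the service-level status duplicated on each row; that is why the call sites above read `statuses[0].service_status`. A hedged sketch of consuming the richer result (the client and identifier variables are assumed to be in scope):

    statuses = svc_client.get_service_container_statuses(
        database_name=db, schema_name=schema, service_name=service_name
    )
    overall = statuses[0].service_status if statuses else None
    failed = [s for s in statuses if s.container_status is ContainerStatus.FAILED]
    print(f"service={overall}, failed containers={len(failed)}")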
@@ -12,9 +12,12 @@ class StageSQLClient(_base._BaseSQLClient):
         schema_name: Optional[sql_identifier.SqlIdentifier],
         stage_name: sql_identifier.SqlIdentifier,
         statement_params: Optional[dict[str, Any]] = None,
-    ) -> None:
+    ) -> str:
+        fq_stage_name = self.fully_qualified_object_name(database_name, schema_name, stage_name)
         query_result_checker.SqlResultValidator(
             self._session,
-            f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
+            f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+        return fq_stage_name
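
Returning the fully qualified name saves callers from re-deriving the identifier they just passed in. A sketch of the intended call pattern (the upload line is illustrative of Snowpark's file API, not taken from this diff):

    fq_stage = stage_client.create_tmp_stage(
        database_name=db, schema_name=schema, stage_name=stage_name
    )
    # The returned name can be used directly to address the stage:
    session.file.put("model_deployment_spec.yaml", f"@{fq_stage}", auto_compress=False)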
@@ -188,7 +188,9 @@ class ModelComposer:
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
@@ -216,7 +216,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     explain_fn=cls._build_explain_fn(model, background_data, input_signature),
                     output_feature_names=transformed_background_data.columns,
                 )
-            except ValueError:
+            except Exception:
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
 import cloudpickle
 import numpy as np
 import pandas as pd
+import shap
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework.base import BaseEstimator
 
 
+def _apply_transforms_up_to_last_step(
+    model: "BaseEstimator",
+    data: model_types.SupportedDataType,
+) -> pd.DataFrame:
+    """Apply all transformations in the snowml pipeline model up to the last step."""
+    if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            data = pd.DataFrame(step.transform(data))
+    return data
+
+
 @final
 class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     """Handler for SnowML based model.
@@ -39,7 +53,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
-    EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     IS_AUTO_SIGNATURE = True
 
@@ -97,11 +111,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 return result
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
-
-        if enable_explainability:
-            raise ValueError(
-                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-            )
         return None
 
     @classmethod
@@ -189,23 +198,46 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         else:
             enable_explainability = True
         if enable_explainability:
-            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
-                python_base_obj, model_meta.task
-            )
-            model_meta.task = model_task_and_output_type.task
-            model_meta = handlers_utils.add_explain_method_signature(
-                model_meta=model_meta,
-                explain_method="explain",
-                target_method=explain_target_method,
-                output_return_type=model_task_and_output_type.output_type,
-            )
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
-            )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
-                )
+            try:
+                model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                    python_base_obj, model_meta.task
+                )
+                model_meta.task = model_task_and_output_type.task
+                background_data = handlers_utils.get_explainability_supported_background(
+                    sample_input_data, model_meta, explain_target_method
+                )
+                if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+                    transformed_df = _apply_transforms_up_to_last_step(model, sample_input_data)
+                    explain_fn = cls._build_explain_fn(model, background_data)
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,  # type: ignore[arg-type]
+                        background_data=background_data,
+                        explain_fn=explain_fn,
+                        output_feature_names=transformed_df.columns,
+                    )
+                else:
+                    model_meta = handlers_utils.add_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        output_return_type=model_task_and_output_type.output_type,
+                    )
+                if background_data is not None:
+                    handlers_utils.save_background_data(
+                        model_blobs_dir_path,
+                        cls.EXPLAIN_ARTIFACTS_DIR,
+                        cls.BG_DATA_FILE_SUFFIX,
+                        name,
+                        background_data,
+                    )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -251,6 +283,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         assert isinstance(m, BaseEstimator)
         return m
 
+    @classmethod
+    def _build_explain_fn(
+        cls, model: "BaseEstimator", background_data: model_types.SupportedDataType
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+
+        predictor = model
+        is_pipeline = type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model)
+        if is_pipeline:
+            background_data = _apply_transforms_up_to_last_step(model, background_data)
+            predictor = model.steps[-1][1]  # type: ignore[attr-defined]
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            data = _apply_transforms_up_to_last_step(model, data)
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn", None]  # None just uses the predictor directly
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(predictor, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    return handlers_utils.convert_explanations_to_2D_df(model, explainer.shap_values(data))
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:  # type: ignore[assignment]
+                try:
+                    base_model = getattr(predictor, method_name)() if method_name is not None else predictor
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                    except TypeError:
+                        for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                            if not hasattr(base_model, explain_target_method):
+                                continue
+                            explain_target_method_fn = getattr(base_model, explain_target_method)
+                            if isinstance(data, np.ndarray):
+                                explainer = shap.Explainer(
+                                    explain_target_method_fn,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(explain_target_method_fn, background_data)
+                            return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                except Exception:
+                    pass  # Do nothing and continue to the next method
+            raise ValueError("Explainability for this model is not supported.")
+
+        return explain_fn
+
     @classmethod
     def convert_as_custom_model(
         cls,
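
The fallback chain in `_build_explain_fn` (shap.TreeExplainer for models convertible to xgboost/lightgbm, then a generic shap.Explainer over the model or one of its prediction methods) follows standard shap usage. A condensed sketch of the same cascade on a plain sklearn model (assumes shap is installed; not the handler's exact logic):

    import numpy as np
    import pandas as pd
    import shap
    from sklearn.ensemble import RandomForestClassifier

    X = pd.DataFrame(np.random.rand(50, 3), columns=["f0", "f1", "f2"])
    y = (X["f0"] > 0.5).astype(int)
    model = RandomForestClassifier(n_estimators=10).fit(X, y)

    try:
        explainer = shap.TreeExplainer(model)               # fast path for tree models
        values = explainer.shap_values(X)
    except Exception:
        explainer = shap.Explainer(model.predict_proba, X)  # generic fallback
        values = explainer(X).values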
@@ -286,57 +365,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            tree_methods = ["to_xgboost", "to_lightgbm"]
-            non_tree_methods = ["to_sklearn"]
-            for method_name in tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    explainer = shap.TreeExplainer(base_model)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            for method_name in non_tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    try:
-                        explainer = shap.Explainer(base_model, masker=background_data)
-                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                    except TypeError:
-                        try:
-                            dtype_map = {
-                                spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
-                            }
-
-                            if isinstance(X, pd.DataFrame):
-                                X = X.astype(dtype_map, copy=False)
-                            if hasattr(base_model, "predict_proba"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict_proba,
-                                        background_data.values,  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
-                            elif hasattr(base_model, "predict"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict, background_data.values  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict, background_data)
-                            else:
-                                raise ValueError("Missing any supported target method to explain.")
-                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                        except TypeError as e:
-                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
+            fn = cls._build_explain_fn(raw_model, background_data)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn