snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +14 -0
  5. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  6. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  7. snowflake/ml/data/data_connector.py +59 -6
  8. snowflake/ml/data/data_ingestor.py +18 -1
  9. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  10. snowflake/ml/data/torch_dataset.py +33 -0
  11. snowflake/ml/dataset/dataset_metadata.py +3 -1
  12. snowflake/ml/dataset/dataset_reader.py +9 -3
  13. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  16. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  25. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  26. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  27. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  29. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  30. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  31. snowflake/ml/feature_store/feature_store.py +59 -24
  32. snowflake/ml/feature_store/feature_view.py +148 -4
  33. snowflake/ml/model/_client/model/model_impl.py +11 -2
  34. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  35. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  36. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  37. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  38. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  39. snowflake/ml/model/_client/sql/model_version.py +13 -4
  40. snowflake/ml/model/_client/sql/service.py +129 -0
  41. snowflake/ml/model/_model_composer/model_composer.py +3 -0
  42. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  44. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  45. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
  47. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  48. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
  50. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  51. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  52. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  53. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  54. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  55. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  57. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  58. snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
  59. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  60. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  61. snowflake/ml/model/_packager/model_packager.py +2 -1
  62. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  63. snowflake/ml/model/type_hints.py +1 -3
  64. snowflake/ml/modeling/framework/base.py +28 -19
  65. snowflake/ml/modeling/pipeline/pipeline.py +3 -0
  66. snowflake/ml/registry/_manager/model_manager.py +16 -2
  67. snowflake/ml/utils/sql_client.py +22 -0
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
  70. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
  71. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  72. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
  74. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
+ from typing import List, TypedDict
+
+ from typing_extensions import NotRequired, Required
+
+
+ class ModelDict(TypedDict):
+     name: Required[str]
+     version: Required[str]
+
+
+ class ImageBuildDict(TypedDict):
+     compute_pool: Required[str]
+     image_repo: Required[str]
+     image_name: NotRequired[str]
+     force_rebuild: Required[bool]
+     external_access_integrations: Required[List[str]]
+
+
+ class ServiceDict(TypedDict):
+     name: Required[str]
+     compute_pool: Required[str]
+     ingress_enabled: Required[bool]
+     min_instances: Required[int]
+     max_instances: Required[int]
+     gpu: NotRequired[str]
+
+
+ class ModelDeploymentSpecDict(TypedDict):
+     models: Required[List[ModelDict]]
+     image_build: Required[ImageBuildDict]
+     service: Required[ServiceDict]
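The new `model_deployment_spec_schema.py` describes the deployment spec as nested TypedDicts. For reference, a spec conforming to these types would look like the following sketch; every literal value is illustrative, not taken from the package:

```python
# Illustrative instance of ModelDeploymentSpecDict; all names are placeholders.
# "image_name" and "gpu" are NotRequired and omitted here.
spec = {
    "models": [{"name": "MY_MODEL", "version": "V1"}],
    "image_build": {
        "compute_pool": "MY_COMPUTE_POOL",
        "image_repo": "/MY_DB/MY_SCHEMA/MY_IMAGE_REPO",
        "force_rebuild": False,
        "external_access_integrations": ["MY_EAI"],
    },
    "service": {
        "name": "MY_SERVICE",
        "compute_pool": "MY_COMPUTE_POOL",
        "ingress_enabled": True,
        "min_instances": 1,
        "max_instances": 1,
    },
}
```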
@@ -371,6 +371,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
          returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
          partition_column: Optional[sql_identifier.SqlIdentifier],
          statement_params: Optional[Dict[str, Any]] = None,
+         is_partitioned: bool = True,
      ) -> dataframe.DataFrame:
          with_statements = []
          if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -409,12 +410,20 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
          sql = textwrap.dedent(
              f"""WITH {','.join(with_statements)}
-                 SELECT *,
-                 FROM {INTERMEDIATE_TABLE_NAME},
-                     TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
-                     OVER (PARTITION BY {partition_by}))"""
+                 SELECT *,
+                 FROM {INTERMEDIATE_TABLE_NAME},
+                     TABLE({module_version_alias}!{method_name.identifier()}({args_sql}))"""
          )
 
+         if is_partitioned or partition_column is not None:
+             sql = textwrap.dedent(
+                 f"""WITH {','.join(with_statements)}
+                     SELECT *,
+                     FROM {INTERMEDIATE_TABLE_NAME},
+                         TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
+                         OVER (PARTITION BY {partition_by}))"""
+             )
+
          output_df = self._session.sql(sql)
 
          # Prepare the output
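This hunk makes the `OVER (PARTITION BY ...)` clause conditional: a plain table-function invocation is built first, then the partitioned form replaces it when `is_partitioned` (default True) or an explicit `partition_column` is set. Roughly, the two SQL shapes mirror the templates above; the alias, method, and column names in this sketch are placeholders:

```python
# Sketch of the two generated SQL shapes (placeholders, not real identifiers).
non_partitioned = """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (...)
SELECT *,
FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT,
    TABLE(MODEL_VERSION_ALIAS!MY_METHOD(COL1, COL2))"""

partitioned = """WITH SNOWPARK_ML_MODEL_INFERENCE_INPUT AS (...)
SELECT *,
FROM SNOWPARK_ML_MODEL_INFERENCE_INPUT,
    TABLE(MODEL_VERSION_ALIAS!MY_METHOD(COL1, COL2)
    OVER (PARTITION BY PARTITION_COL))"""
```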
@@ -0,0 +1,129 @@
+ import textwrap
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from snowflake.ml._internal.utils import (
+     identifier,
+     query_result_checker,
+     sql_identifier,
+ )
+ from snowflake.ml.model._client.sql import _base
+ from snowflake.snowpark import dataframe, functions as F, types as spt
+ from snowflake.snowpark._internal import utils as snowpark_utils
+
+
+ class ServiceSQLClient(_base._BaseSQLClient):
+     def build_model_container(
+         self,
+         *,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         version_name: sql_identifier.SqlIdentifier,
+         compute_pool_name: sql_identifier.SqlIdentifier,
+         image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
+         image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
+         image_repo_name: sql_identifier.SqlIdentifier,
+         gpu: Optional[str],
+         force_rebuild: bool,
+         external_access_integration: sql_identifier.SqlIdentifier,
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         actual_image_repo_database = image_repo_database_name or self._database_name
+         actual_image_repo_schema = image_repo_schema_name or self._schema_name
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         fq_image_repo_name = "/" + "/".join(
+             [
+                 actual_image_repo_database.identifier(),
+                 actual_image_repo_schema.identifier(),
+                 image_repo_name.identifier(),
+             ]
+         )
+         is_gpu = gpu is not None
+         query_result_checker.SqlResultValidator(
+             self._session,
+             (
+                 f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
+                 f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')"
+             ),
+             statement_params=statement_params,
+         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+     def deploy_model(
+         self,
+         *,
+         stage_path: str,
+         model_deployment_spec_file_rel_path: str,
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         query_result_checker.SqlResultValidator(
+             self._session,
+             f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')",
+             statement_params=statement_params,
+         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+     def invoke_function_method(
+         self,
+         *,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         service_name: sql_identifier.SqlIdentifier,
+         method_name: sql_identifier.SqlIdentifier,
+         input_df: dataframe.DataFrame,
+         input_args: List[sql_identifier.SqlIdentifier],
+         returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> dataframe.DataFrame:
+         with_statements = []
+         if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
+             INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
+             with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
+         else:
+             actual_database_name = database_name or self._database_name
+             actual_schema_name = schema_name or self._schema_name
+             tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
+             INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
+                 actual_database_name.identifier(),
+                 actual_schema_name.identifier(),
+                 tmp_table_name,
+             )
+             input_df.write.save_as_table(
+                 table_name=INTERMEDIATE_TABLE_NAME,
+                 mode="errorifexists",
+                 table_type="temporary",
+                 statement_params=statement_params,
+             )
+
+         INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
+
+         with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
+         args_sql_list = []
+         for input_arg_value in input_args:
+             args_sql_list.append(input_arg_value)
+         args_sql = ", ".join(args_sql_list)
+
+         sql = textwrap.dedent(
+             f"""{with_sql}
+             SELECT *,
+                 {service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
+             FROM {INTERMEDIATE_TABLE_NAME}"""
+         )
+
+         output_df = self._session.sql(sql)
+
+         # Prepare the output
+         output_cols = []
+         output_names = []
+
+         for output_name, output_type, output_col_name in returns:
+             output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_name].astype(output_type))
+             output_names.append(output_col_name)
+
+         output_df = output_df.with_columns(
+             col_names=output_names,
+             values=output_cols,
+         ).drop(INTERMEDIATE_OBJ_NAME)
+
+         if statement_params:
+             output_df._statement_params = statement_params  # type: ignore[assignment]
+
+         return output_df
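A hedged usage sketch of the new `ServiceSQLClient`: the constructor arguments follow `_base._BaseSQLClient`, and the session, identifiers, and paths below are assumptions, not values from the package:

```python
from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.model._client.sql import service

# `session` is an existing snowflake.snowpark.Session (assumed to exist).
client = service.ServiceSQLClient(
    session,
    database_name=sql_identifier.SqlIdentifier("MY_DB"),
    schema_name=sql_identifier.SqlIdentifier("MY_SCHEMA"),
)

# Issues CALL SYSTEM$DEPLOY_MODEL('@MY_STAGE/deploy.yml') and validates that
# the call returns exactly one row and one column.
client.deploy_model(
    stage_path="MY_STAGE",
    model_deployment_spec_file_rel_path="deploy.yml",
)
```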
@@ -10,6 +10,7 @@ from absl import logging
  from packaging import requirements
  from typing_extensions import deprecated
 
+ from snowflake import snowpark
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
  from snowflake.ml._internal.lineage import lineage_utils
  from snowflake.ml.data import data_source
@@ -185,4 +186,6 @@ class ModelComposer:
          data_sources = lineage_utils.get_data_sources(model)
          if not data_sources and sample_input_data is not None:
              data_sources = lineage_utils.get_data_sources(sample_input_data)
+             if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
+                 data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
          return data_sources
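The fallback captures query lineage when the sample input is a Snowpark DataFrame that carries no explicit data sources: the last query backing the DataFrame becomes the lineage entity. A minimal sketch, where `session` and the table name are assumptions:

```python
# Minimal sketch; `session` and the table name are assumed.
df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")

# A Snowpark DataFrame exposes the SQL that materializes it; the last entry
# is what gets recorded as a DataFrameInfo lineage source.
last_query = df.queries["queries"][-1]
sources = [data_source.DataFrameInfo(last_query)]
```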
@@ -6,6 +6,7 @@ from typing import List, Optional, cast
 
  import yaml
 
+ from snowflake.ml._internal import env_utils
  from snowflake.ml.data import data_source
  from snowflake.ml.model import type_hints
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -47,7 +48,9 @@ class ModelManifest:
          runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
          runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
          runtime_to_use.imports.append(str(model_rel_path) + "/")
-         runtime_dict = runtime_to_use.save(self.workspace_path)
+         runtime_dict = runtime_to_use.save(
+             self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+         )
 
          self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
          self.methods: List[model_method.ModelMethod] = []
@@ -137,10 +140,15 @@ class ModelManifest:
              if isinstance(source, data_source.DatasetInfo):
                  result.append(
                      model_manifest_schema.LineageSourceDict(
-                         # Currently, we only support lineage from Dataset.
                          type=model_manifest_schema.LineageSourceTypes.DATASET.value,
                          entity=source.fully_qualified_name,
                          version=source.version,
                      )
                  )
+             elif isinstance(source, data_source.DataFrameInfo):
+                 result.append(
+                     model_manifest_schema.LineageSourceDict(
+                         type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
+                     )
+                 )
          return result
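With the new QUERY source type, the manifest's lineage section can now carry either entry shape. A sketch of the two forms, with placeholder values:

```python
# Illustrative lineage entries; entity/version values are placeholders.
dataset_entry = model_manifest_schema.LineageSourceDict(
    type="DATASET",
    entity="MY_DB.MY_SCHEMA.MY_DATASET",
    version="v1",
)
query_entry = model_manifest_schema.LineageSourceDict(
    type="QUERY",
    entity="SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA",
)
```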
@@ -57,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
          target_method: actual target method name to be called.
          target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
          signature: The signature of the model method.
+         is_partitioned: Whether the function is partitioned.
      """
 
      name: Required[str]
      target_method: Required[str]
      target_method_function_type: Required[str]
      signature: Required[model_signature.ModelSignature]
+     is_partitioned: Required[bool]
 
 
  class ModelFunctionInfoDict(TypedDict):
@@ -78,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
 
  class LineageSourceTypes(enum.Enum):
      DATASET = "DATASET"
+     QUERY = "QUERY"
 
 
  class LineageSourceDict(TypedDict):
@@ -363,9 +363,14 @@ class ModelEnv:
          self.cuda_version = env_dict.get("cuda_version", None)
          self.snowpark_ml_version = env_dict["snowpark_ml_version"]
 
-     def save_as_dict(self, base_dir: pathlib.Path) -> model_meta_schema.ModelEnvDict:
+     def save_as_dict(
+         self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+     ) -> model_meta_schema.ModelEnvDict:
          env_utils.save_conda_env_file(
-             pathlib.Path(base_dir / self.conda_env_rel_path), self._conda_dependencies, self.python_version
+             pathlib.Path(base_dir / self.conda_env_rel_path),
+             self._conda_dependencies,
+             self.python_version,
+             default_channel_override=default_channel_override,
          )
          env_utils.save_requirements_file(
              pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements
@@ -1,7 +1,8 @@
1
+ import os
1
2
  from abc import abstractmethod
2
- from enum import Enum
3
3
  from typing import Dict, Generic, Optional, Protocol, Type, final
4
4
 
5
+ import pandas as pd
5
6
  from typing_extensions import TypeGuard, Unpack
6
7
 
7
8
  from snowflake.ml.model import custom_model, type_hints as model_types
@@ -9,15 +10,6 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
9
10
  from snowflake.ml.model._packager.model_meta import model_meta
10
11
 
11
12
 
12
- class ModelObjective(Enum):
13
- # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
14
- UNKNOWN = "unknown"
15
- BINARY_CLASSIFICATION = "binary_classification"
16
- MULTI_CLASSIFICATION = "multi_classification"
17
- REGRESSION = "regression"
18
- RANKING = "ranking"
19
-
20
-
21
13
  class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
22
14
  HANDLER_TYPE: model_types.SupportedModelHandlerType
23
15
  HANDLER_VERSION: str
@@ -106,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
106
98
  cls,
107
99
  raw_model: model_types._ModelType,
108
100
  model_meta: model_meta.ModelMetadata,
101
+ background_data: Optional[pd.DataFrame] = None,
109
102
  **kwargs: Unpack[model_types.BaseModelLoadOption],
110
103
  ) -> custom_model.CustomModel:
111
104
  """Create a custom model class wrap for unified interface when being deployed. The predict method will be
@@ -114,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
114
107
  Args:
115
108
  raw_model: original model object,
116
109
  model_meta: The model metadata.
110
+ background_data: The background data used for the model explanations.
117
111
  kwargs: Options when converting the model.
118
112
 
119
113
  Raises:
@@ -131,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
131
125
  _MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
132
126
  _HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.
133
127
 
134
- MODELE_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
128
+ MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
129
+ BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
135
130
  MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
136
131
  DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
137
132
  ["predict"]
@@ -139,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
139
134
  inputting sample data or model signature. Default to False.
140
135
  """
141
136
 
142
- MODELE_BLOB_FILE_OR_DIR = "model.pkl"
137
+ MODEL_BLOB_FILE_OR_DIR = "model.pkl"
138
+ BG_DATA_FILE_SUFFIX = "_background_data.pqt"
143
139
  MODEL_ARTIFACTS_DIR = "artifacts"
140
+ EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
144
141
  DEFAULT_TARGET_METHODS = ["predict"]
145
142
  IS_AUTO_SIGNATURE = False
146
143
 
@@ -169,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
              model_meta=model_meta,
              model_blobs_dir_path=model_blobs_dir_path,
          )
+
+     @classmethod
+     @final
+     def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
+         """Load the background data into memory.
+
+         Args:
+             name: Name of the model.
+             model_blobs_dir_path: Directory path to the whole model.
+
+         Returns:
+             Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
+         """
+         data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
+         if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
+             return None
+         with open(data_blob_path, "rb") as f:
+             background_data = pd.read_parquet(f)
+
+         return background_data
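`load_background_data` pairs with the new `EXPLAIN_ARTIFACTS_DIR` and `BG_DATA_FILE_SUFFIX` constants. The on-disk layout it expects looks like this sketch, with a placeholder model name:

```python
# Expected layout, per the constants above ("my_model" is a placeholder):
#
#   <model_blobs_dir_path>/
#       explain_artifacts/
#           my_model_background_data.pqt   # parquet, read via pd.read_parquet
import os

data_blob_path = os.path.join(
    "/path/to/model_blobs", "explain_artifacts", "my_model" + "_background_data.pqt"
)
```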
@@ -30,24 +30,24 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
      _MIN_SNOWPARK_ML_VERSION = "1.3.1"
      _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-     MODELE_BLOB_FILE_OR_DIR = "model.bin"
+     MODEL_BLOB_FILE_OR_DIR = "model.bin"
      DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
 
      @classmethod
-     def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
+     def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
          import catboost
 
          if isinstance(model, catboost.CatBoostClassifier):
              num_classes = handlers_utils.get_num_classes_if_exists(model)
              if num_classes == 2:
-                 return _base.ModelObjective.BINARY_CLASSIFICATION
-             return _base.ModelObjective.MULTI_CLASSIFICATION
+                 return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+             return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
          if isinstance(model, catboost.CatBoostRanker):
-             return _base.ModelObjective.RANKING
+             return model_meta_schema.ModelObjective.RANKING
          if isinstance(model, catboost.CatBoostRegressor):
-             return _base.ModelObjective.REGRESSION
+             return model_meta_schema.ModelObjective.REGRESSION
          # TODO: Find out model type from the generic Catboost Model
-         return _base.ModelObjective.UNKNOWN
+         return model_meta_schema.ModelObjective.UNKNOWN
 
      @classmethod
      def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
@@ -105,9 +105,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
              sample_input_data=sample_input_data,
              get_prediction_fn=get_prediction,
          )
-         if kwargs.get("enable_explainability", False):
+         model_objective = cls.get_model_objective(model)
+         model_meta.model_objective = model_objective
+         if kwargs.get("enable_explainability", True):
              output_type = model_signature.DataType.DOUBLE
-             if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+             if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
                  output_type = model_signature.DataType.STRING
              model_meta = handlers_utils.add_explain_method_signature(
                  model_meta=model_meta,
@@ -115,10 +117,13 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
                  target_method="predict",
                  output_return_type=output_type,
              )
+             model_meta.function_properties = {
+                 "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+             }
 
          model_blob_path = os.path.join(model_blobs_dir_path, name)
          os.makedirs(model_blob_path, exist_ok=True)
-         model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+         model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
 
          model.save_model(model_save_path)
 
@@ -126,7 +131,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
              name=name,
              model_type=cls.HANDLER_TYPE,
              handler_version=cls.HANDLER_VERSION,
-             path=cls.MODELE_BLOB_FILE_OR_DIR,
+             path=cls.MODEL_BLOB_FILE_OR_DIR,
              options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
          )
          model_meta.models[name] = base_meta
@@ -138,11 +143,12 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
              ],
              check_local_version=True,
          )
-         if kwargs.get("enable_explainability", False):
+         if kwargs.get("enable_explainability", True):
              model_meta.env.include_if_absent(
                  [model_env.ModelDependency(requirement="shap", pip_name="shap")],
                  check_local_version=True,
              )
+             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
          model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
          return None
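Note the default flip here: `enable_explainability` now defaults to True for CatBoost, so SHAP is added as a dependency and an `explain` method is attached unless the caller opts out. A hedged sketch of opting out at log time, assuming a `snowflake.ml.registry` Registry instance and that its `options` dict accepts the same `enable_explainability` key this handler reads from kwargs:

```python
# Hedged sketch; `registry` and the model/version names are assumptions,
# and the option key mirrors kwargs.get("enable_explainability", ...) above.
mv = registry.log_model(
    model=catboost_model,
    model_name="MY_CATBOOST_MODEL",
    version_name="V1",
    options={"enable_explainability": False},
)
```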
@@ -188,6 +194,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
          cls,
          raw_model: "catboost.CatBoost",
          model_meta: model_meta_api.ModelMetadata,
+         background_data: Optional[pd.DataFrame] = None,
          **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
      ) -> custom_model.CustomModel:
          import catboost
@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
          **kwargs: Unpack[model_types.CustomModelSaveOption],
      ) -> None:
          assert isinstance(model, custom_model.CustomModel)
+         enable_explainability = kwargs.get("enable_explainability", False)
+         if enable_explainability:
+             raise NotImplementedError("Explainability is not supported for custom model.")
 
          def get_prediction(
              target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
          # Make sure that the module where the model is defined get pickled by value as well.
          cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
          pickled_obj = (model.__class__, model.context)
-         with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+         with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
              cloudpickle.dump(pickled_obj, f)
          # model meta will be saved by the context manager
          model_meta.models[name] = model_blob_meta.ModelBlobMeta(
              name=name,
              model_type=cls.HANDLER_TYPE,
-             path=cls.MODELE_BLOB_FILE_OR_DIR,
+             path=cls.MODEL_BLOB_FILE_OR_DIR,
              handler_version=cls.HANDLER_VERSION,
              function_properties=model_meta.function_properties,
              artifacts={
@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
          cls,
          raw_model: custom_model.CustomModel,
          model_meta: model_meta_api.ModelMetadata,
+         background_data: Optional[pd.DataFrame] = None,
          **kwargs: Unpack[model_types.CustomModelLoadOption],
      ) -> custom_model.CustomModel:
          return raw_model
@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
      _MIN_SNOWPARK_ML_VERSION = "1.0.12"
      _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-     MODELE_BLOB_FILE_OR_DIR = "model"
+     MODEL_BLOB_FILE_OR_DIR = "model"
      ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
      DEFAULT_TARGET_METHODS = ["__call__"]
      IS_AUTO_SIGNATURE = True
@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
          is_sub_model: Optional[bool] = False,
          **kwargs: Unpack[model_types.HuggingFaceSaveOptions],
      ) -> None:
+         enable_explainability = kwargs.get("enable_explainability", False)
+         if enable_explainability:
+             raise NotImplementedError("Explainability is not supported for huggingface model.")
          if type_utils.LazyType("transformers.Pipeline").isinstance(model):
              task = model.task  # type:ignore[attr-defined]
              framework = model.framework  # type:ignore[attr-defined]
@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(
 
          if type_utils.LazyType("transformers.Pipeline").isinstance(model):
              model.save_pretrained(  # type:ignore[attr-defined]
-                 os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+                 os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
              )
              pipeline_params = {
                  "_batch_size": model._batch_size,  # type:ignore[attr-defined]
@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
              with open(
                  os.path.join(
                      model_blob_path,
-                     cls.MODELE_BLOB_FILE_OR_DIR,
+                     cls.MODEL_BLOB_FILE_OR_DIR,
                      cls.ADDITIONAL_CONFIG_FILE,
                  ),
                  "wb",
@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
                  cloudpickle.dump(pipeline_params, f)
          else:
              with open(
-                 os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR),
+                 os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
                  "wb",
              ) as f:
                  cloudpickle.dump(model, f)
@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
              name=name,
              model_type=cls.HANDLER_TYPE,
              handler_version=cls.HANDLER_VERSION,
-             path=cls.MODELE_BLOB_FILE_OR_DIR,
+             path=cls.MODEL_BLOB_FILE_OR_DIR,
              options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
                  {
                      "task": task,
@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
          cls,
          raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
          model_meta: model_meta_api.ModelMetadata,
+         background_data: Optional[pd.DataFrame] = None,
          **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
      ) -> custom_model.CustomModel:
          import transformers