snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +14 -0
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +59 -6
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +59 -24
- snowflake/ml/feature_store/feature_view.py +148 -4
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_model_composer/model_composer.py +3 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/type_hints.py +1 -3
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/pipeline/pipeline.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/service/model_deployment_spec_schema.py (new file)

@@ -0,0 +1,31 @@
+from typing import List, TypedDict
+
+from typing_extensions import NotRequired, Required
+
+
+class ModelDict(TypedDict):
+    name: Required[str]
+    version: Required[str]
+
+
+class ImageBuildDict(TypedDict):
+    compute_pool: Required[str]
+    image_repo: Required[str]
+    image_name: NotRequired[str]
+    force_rebuild: Required[bool]
+    external_access_integrations: Required[List[str]]
+
+
+class ServiceDict(TypedDict):
+    name: Required[str]
+    compute_pool: Required[str]
+    ingress_enabled: Required[bool]
+    min_instances: Required[int]
+    max_instances: Required[int]
+    gpu: NotRequired[str]
+
+
+class ModelDeploymentSpecDict(TypedDict):
+    models: Required[List[ModelDict]]
+    image_build: Required[ImageBuildDict]
+    service: Required[ServiceDict]
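For reference, a minimal deployment spec that satisfies these TypedDicts could look like the sketch below; every identifier value is illustrative rather than taken from the package.

from snowflake.ml.model._client.service.model_deployment_spec_schema import (
    ModelDeploymentSpecDict,
)

spec: ModelDeploymentSpecDict = {
    "models": [{"name": "MY_MODEL", "version": "V1"}],
    "image_build": {
        "compute_pool": "MY_POOL",
        "image_repo": "MY_DB/MY_SCHEMA/MY_REPO",
        "force_rebuild": False,
        "external_access_integrations": ["MY_EAI"],
    },
    "service": {
        "name": "MY_SERVICE",
        "compute_pool": "MY_POOL",
        "ingress_enabled": True,
        "min_instances": 1,
        "max_instances": 1,
    },
}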
snowflake/ml/model/_client/sql/model_version.py

@@ -371,6 +371,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
         partition_column: Optional[sql_identifier.SqlIdentifier],
         statement_params: Optional[Dict[str, Any]] = None,
+        is_partitioned: bool = True,
     ) -> dataframe.DataFrame:
         with_statements = []
         if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:

@@ -409,12 +410,20 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
 
         sql = textwrap.dedent(
             f"""WITH {','.join(with_statements)}
-                SELECT *,
-                FROM {INTERMEDIATE_TABLE_NAME},
-                    TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
-                    OVER (PARTITION BY {partition_by}))"""
+                SELECT *,
+                FROM {INTERMEDIATE_TABLE_NAME},
+                    TABLE({module_version_alias}!{method_name.identifier()}({args_sql}))"""
         )
 
+        if is_partitioned or partition_column is not None:
+            sql = textwrap.dedent(
+                f"""WITH {','.join(with_statements)}
+                    SELECT *,
+                    FROM {INTERMEDIATE_TABLE_NAME},
+                        TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
+                        OVER (PARTITION BY {partition_by}))"""
+            )
+
         output_df = self._session.sql(sql)
 
         # Prepare the output
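The new is_partitioned flag is easiest to follow in isolation: the method now emits a plain table-function call by default and only adds the OVER (PARTITION BY ...) clause when the method is partitioned or a partition column is given. The helper below is a self-contained sketch of that branch, not part of the package; the "1" fallback for partition_by is an assumption.

import textwrap
from typing import List, Optional


def render_invoke_sql(
    with_statements: List[str],
    table_name: str,
    call_expr: str,
    partition_by: Optional[str] = None,
    is_partitioned: bool = True,
) -> str:
    # Default shape: no OVER clause.
    sql = textwrap.dedent(
        f"""WITH {','.join(with_statements)}
        SELECT *
        FROM {table_name},
            TABLE({call_expr})"""
    )
    if is_partitioned or partition_by is not None:
        # Partitioned shape; the "1" fallback when no column is given is hypothetical.
        sql = textwrap.dedent(
            f"""WITH {','.join(with_statements)}
            SELECT *
            FROM {table_name},
                TABLE({call_expr}
                OVER (PARTITION BY {partition_by or '1'}))"""
        )
    return sql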
snowflake/ml/model/_client/sql/service.py (new file)

@@ -0,0 +1,129 @@
+import textwrap
+from typing import Any, Dict, List, Optional, Tuple
+
+from snowflake.ml._internal.utils import (
+    identifier,
+    query_result_checker,
+    sql_identifier,
+)
+from snowflake.ml.model._client.sql import _base
+from snowflake.snowpark import dataframe, functions as F, types as spt
+from snowflake.snowpark._internal import utils as snowpark_utils
+
+
+class ServiceSQLClient(_base._BaseSQLClient):
+    def build_model_container(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        compute_pool_name: sql_identifier.SqlIdentifier,
+        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_name: sql_identifier.SqlIdentifier,
+        gpu: Optional[str],
+        force_rebuild: bool,
+        external_access_integration: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        actual_image_repo_database = image_repo_database_name or self._database_name
+        actual_image_repo_schema = image_repo_schema_name or self._schema_name
+        fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+        fq_image_repo_name = "/" + "/".join(
+            [
+                actual_image_repo_database.identifier(),
+                actual_image_repo_schema.identifier(),
+                image_repo_name.identifier(),
+            ]
+        )
+        is_gpu = gpu is not None
+        query_result_checker.SqlResultValidator(
+            self._session,
+            (
+                f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
+                f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')"
+            ),
+            statement_params=statement_params,
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def deploy_model(
+        self,
+        *,
+        stage_path: str,
+        model_deployment_spec_file_rel_path: str,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')",
+            statement_params=statement_params,
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def invoke_function_method(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        method_name: sql_identifier.SqlIdentifier,
+        input_df: dataframe.DataFrame,
+        input_args: List[sql_identifier.SqlIdentifier],
+        returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> dataframe.DataFrame:
+        with_statements = []
+        if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
+            INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
+            with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
+        else:
+            actual_database_name = database_name or self._database_name
+            actual_schema_name = schema_name or self._schema_name
+            tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
+            INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
+                actual_database_name.identifier(),
+                actual_schema_name.identifier(),
+                tmp_table_name,
+            )
+            input_df.write.save_as_table(
+                table_name=INTERMEDIATE_TABLE_NAME,
+                mode="errorifexists",
+                table_type="temporary",
+                statement_params=statement_params,
+            )
+
+        INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
+
+        with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
+        args_sql_list = []
+        for input_arg_value in input_args:
+            args_sql_list.append(input_arg_value)
+        args_sql = ", ".join(args_sql_list)
+
+        sql = textwrap.dedent(
+            f"""{with_sql}
+            SELECT *,
+                {service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
+            FROM {INTERMEDIATE_TABLE_NAME}"""
+        )
+
+        output_df = self._session.sql(sql)
+
+        # Prepare the output
+        output_cols = []
+        output_names = []
+
+        for output_name, output_type, output_col_name in returns:
+            output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_name].astype(output_type))
+            output_names.append(output_col_name)
+
+        output_df = output_df.with_columns(
+            col_names=output_names,
+            values=output_cols,
+        ).drop(INTERMEDIATE_OBJ_NAME)
+
+        if statement_params:
+            output_df._statement_params = statement_params  # type: ignore[assignment]
+
+        return output_df
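Note the invocation convention here: a deployed service exposes each model method as a regular SQL function named <service>_<method>, with the raw result in a TMP_RESULT column that is then unpacked according to the returns spec. An illustrative rendering, with invented identifiers:

# What the generated statement looks like for a service MY_SERVICE
# exposing PREDICT over input columns A and B (illustrative only).
table = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
sql = (
    f"WITH {table} AS (SELECT 1 AS A, 2 AS B)\n"
    f"SELECT *,\n"
    f"    MY_SERVICE_PREDICT(A, B) AS TMP_RESULT\n"
    f"FROM {table}"
)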
snowflake/ml/model/_model_composer/model_composer.py

@@ -10,6 +10,7 @@ from absl import logging
 from packaging import requirements
 from typing_extensions import deprecated
 
+from snowflake import snowpark
 from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_source

@@ -185,4 +186,6 @@ class ModelComposer:
         data_sources = lineage_utils.get_data_sources(model)
         if not data_sources and sample_input_data is not None:
             data_sources = lineage_utils.get_data_sources(sample_input_data)
+            if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
+                data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
         return data_sources
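The fallback added here captures lineage directly from a Snowpark DataFrame when neither the model nor the sample input carries data sources. A sketch of what gets recorded, assuming an active Snowpark session and an invented table name:

# The last entry of df.queries["queries"] is the SELECT that materializes df;
# ModelComposer wraps it as data_source.DataFrameInfo(...) for lineage.
df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA").filter("LABEL IS NOT NULL")
captured_sql = df.queries["queries"][-1]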
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py

@@ -6,6 +6,7 @@ from typing import List, Optional, cast
 
 import yaml
 
+from snowflake.ml._internal import env_utils
 from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema

@@ -47,7 +48,9 @@ class ModelManifest:
         runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
         runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
         runtime_to_use.imports.append(str(model_rel_path) + "/")
-        runtime_dict = runtime_to_use.save(self.workspace_path)
+        runtime_dict = runtime_to_use.save(
+            self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+        )
 
         self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
         self.methods: List[model_method.ModelMethod] = []

@@ -137,10 +140,15 @@ class ModelManifest:
             if isinstance(source, data_source.DatasetInfo):
                 result.append(
                     model_manifest_schema.LineageSourceDict(
-                        # Currently, we only support lineage from Dataset.
                         type=model_manifest_schema.LineageSourceTypes.DATASET.value,
                         entity=source.fully_qualified_name,
                         version=source.version,
                     )
                 )
+            elif isinstance(source, data_source.DataFrameInfo):
+                result.append(
+                    model_manifest_schema.LineageSourceDict(
+                        type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
+                    )
+                )
         return result
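The two lineage entry shapes differ only in their fields: DATASET sources carry a version, QUERY sources carry the SQL text. Illustrative values:

dataset_source = {
    "type": "DATASET",
    "entity": "MY_DB.MY_SCHEMA.MY_DATASET",  # source.fully_qualified_name
    "version": "V1",
}
query_source = {
    "type": "QUERY",
    "entity": "SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA",  # source.sql
}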
snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py

@@ -57,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
         target_method: actual target method name to be called.
         target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
         signature: The signature of the model method.
+        is_partitioned: Whether the function is partitioned.
     """
 
     name: Required[str]
     target_method: Required[str]
     target_method_function_type: Required[str]
     signature: Required[model_signature.ModelSignature]
+    is_partitioned: Required[bool]
 
 
 class ModelFunctionInfoDict(TypedDict):

@@ -78,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
 
 class LineageSourceTypes(enum.Enum):
     DATASET = "DATASET"
+    QUERY = "QUERY"
 
 
 class LineageSourceDict(TypedDict):
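A populated ModelFunctionInfo under the extended schema might look like the sketch below; the signature object and all values are placeholders, and the pairing of TABLE_FUNCTION with is_partitioned=False mirrors the explain functions registered elsewhere in this release.

info = {
    "name": "EXPLAIN",
    "target_method": "explain",
    "target_method_function_type": "TABLE_FUNCTION",
    "signature": explain_signature,  # placeholder for a model_signature.ModelSignature
    "is_partitioned": False,
}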
snowflake/ml/model/_packager/model_env/model_env.py

@@ -363,9 +363,14 @@ class ModelEnv:
         self.cuda_version = env_dict.get("cuda_version", None)
         self.snowpark_ml_version = env_dict["snowpark_ml_version"]
 
-    def save_as_dict(self, base_dir: pathlib.Path) -> model_meta_schema.ModelEnvDict:
+    def save_as_dict(
+        self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+    ) -> model_meta_schema.ModelEnvDict:
         env_utils.save_conda_env_file(
-            pathlib.Path(base_dir / self.conda_env_rel_path), self._conda_dependencies, self.python_version
+            pathlib.Path(base_dir / self.conda_env_rel_path),
+            self._conda_dependencies,
+            self.python_version,
+            default_channel_override=default_channel_override,
         )
         env_utils.save_requirements_file(
             pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements
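The practical effect is that conda environment files written during packaging now default to the Snowflake Anaconda channel unless a caller overrides it. Call-side sketch, where env is an existing ModelEnv and the workspace path is invented:

import pathlib

from snowflake.ml._internal import env_utils

env_dict = env.save_as_dict(
    pathlib.Path("/tmp/model_workspace"),
    default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
)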
snowflake/ml/model/_packager/model_handlers/_base.py

@@ -1,7 +1,8 @@
+import os
 from abc import abstractmethod
-from enum import Enum
 from typing import Dict, Generic, Optional, Protocol, Type, final
 
+import pandas as pd
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml.model import custom_model, type_hints as model_types

@@ -9,15 +10,6 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import model_meta
 
 
-class ModelObjective(Enum):
-    # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
-    UNKNOWN = "unknown"
-    BINARY_CLASSIFICATION = "binary_classification"
-    MULTI_CLASSIFICATION = "multi_classification"
-    REGRESSION = "regression"
-    RANKING = "ranking"
-
-
 class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
     HANDLER_TYPE: model_types.SupportedModelHandlerType
     HANDLER_VERSION: str

@@ -106,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         cls,
         raw_model: model_types._ModelType,
         model_meta: model_meta.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.BaseModelLoadOption],
     ) -> custom_model.CustomModel:
         """Create a custom model class wrap for unified interface when being deployed. The predict method will be

@@ -114,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         Args:
             raw_model: original model object,
             model_meta: The model metadata.
+            background_data: The background data used for the model explanations.
             kwargs: Options when converting the model.
 
         Raises:

@@ -131,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtocol[model_types._ModelType]):
         _MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
         _HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.
 
-        MODELE_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
+        MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
+        BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
         MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
         DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
             ["predict"]

@@ -139,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtocol[model_types._ModelType]):
             inputting sample data or model signature. Default to False.
     """
 
-    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+    BG_DATA_FILE_SUFFIX = "_background_data.pqt"
     MODEL_ARTIFACTS_DIR = "artifacts"
+    EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
     DEFAULT_TARGET_METHODS = ["predict"]
     IS_AUTO_SIGNATURE = False
 

@@ -169,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtocol[model_types._ModelType]):
             model_meta=model_meta,
             model_blobs_dir_path=model_blobs_dir_path,
         )
+
+    @classmethod
+    @final
+    def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
+        """Load the background data into memory.
+
+        Args:
+            name: Name of the model.
+            model_blobs_dir_path: Directory path to the whole model.
+
+        Returns:
+            Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
+        """
+        data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
+        if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
+            return None
+        with open(data_blob_path, "rb") as f:
+            background_data = pd.read_parquet(f)
+
+        return background_data
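load_background_data implies an on-disk layout with an explain_artifacts/ directory alongside the per-model blob directories. The save side is not shown in this hunk; the helper below is a hypothetical counterpart consistent with the paths the loader checks:

import os

import pandas as pd


def save_background_data(name: str, model_blobs_dir_path: str, df: pd.DataFrame) -> None:
    # Writes <model_blobs_dir>/explain_artifacts/<name>_background_data.pqt,
    # the exact path load_background_data reads.
    explain_dir = os.path.join(model_blobs_dir_path, "explain_artifacts")
    os.makedirs(explain_dir, exist_ok=True)
    df.to_parquet(os.path.join(explain_dir, name + "_background_data.pqt"))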
snowflake/ml/model/_packager/model_handlers/catboost.py

@@ -30,24 +30,24 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.bin"
+    MODEL_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
 
     @classmethod
-    def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
+    def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
         import catboost
 
         if isinstance(model, catboost.CatBoostClassifier):
             num_classes = handlers_utils.get_num_classes_if_exists(model)
             if num_classes == 2:
-                return _base.ModelObjective.BINARY_CLASSIFICATION
-            return _base.ModelObjective.MULTI_CLASSIFICATION
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
         if isinstance(model, catboost.CatBoostRanker):
-            return _base.ModelObjective.RANKING
+            return model_meta_schema.ModelObjective.RANKING
         if isinstance(model, catboost.CatBoostRegressor):
-            return _base.ModelObjective.REGRESSION
+            return model_meta_schema.ModelObjective.REGRESSION
         # TODO: Find out model type from the generic Catboost Model
-        return _base.ModelObjective.UNKNOWN
+        return model_meta_schema.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:

@@ -105,9 +105,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        if kwargs.get("enable_explainability", False):
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
             output_type = model_signature.DataType.DOUBLE
-            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+            if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
                 output_type = model_signature.DataType.STRING
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,

@@ -115,10 +117,13 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
                 target_method="predict",
                 output_return_type=output_type,
             )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
 
         model.save_model(model_save_path)
 

@@ -126,7 +131,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta

@@ -138,11 +143,12 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
-        if kwargs.get("enable_explainability", False):
+        if kwargs.get("enable_explainability", True):
             model_meta.env.include_if_absent(
                 [model_env.ModelDependency(requirement="shap", pip_name="shap")],
                 check_local_version=True,
             )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
         return None

@@ -188,6 +194,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
         cls,
         raw_model: "catboost.CatBoost",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
     ) -> custom_model.CustomModel:
         import catboost
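Taken together, the catboost.py changes turn enable_explainability on by default, record the model objective in the metadata, and mark the explain function as non-partitioned. A registry-level sketch of the resulting behavior; session, catboost_classifier, and X_sample are assumed to exist:

from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.log_model(
    catboost_classifier,  # a fitted catboost.CatBoostClassifier
    model_name="MY_CATBOOST",
    version_name="V1",
    sample_input_data=X_sample,
    # options={"enable_explainability": False},  # explicit opt-out
)
explanations = mv.run(X_sample, function_name="explain")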
snowflake/ml/model/_packager/model_handlers/custom.py

@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         **kwargs: Unpack[model_types.CustomModelSaveOption],
     ) -> None:
         assert isinstance(model, custom_model.CustomModel)
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for custom model.")
 
         def get_prediction(
             target_method_name: str, sample_input_data: model_types.SupportedLocalDataType

@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         # Make sure that the module where the model is defined get pickled by value as well.
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         pickled_obj = (model.__class__, model.context)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(pickled_obj, f)
         # model meta will be saved by the context manager
         model_meta.models[name] = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             handler_version=cls.HANDLER_VERSION,
             function_properties=model_meta.function_properties,
             artifacts={

@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         cls,
         raw_model: custom_model.CustomModel,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CustomModelLoadOption],
     ) -> custom_model.CustomModel:
         return raw_model
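For custom models the same option now fails fast instead of being silently ignored; a sketch reusing the names from the previous example:

# Expected to raise NotImplementedError:
# "Explainability is not supported for custom model."
reg.log_model(
    my_custom_model,  # a custom_model.CustomModel subclass instance
    model_name="MY_CUSTOM",
    version_name="V1",
    sample_input_data=X_sample,
    options={"enable_explainability": True},
)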
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py

@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True

@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.HuggingFaceSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for huggingface model.")
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             task = model.task  # type:ignore[attr-defined]
             framework = model.framework  # type:ignore[attr-defined]

@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(
 
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             model.save_pretrained(  # type:ignore[attr-defined]
-                os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             )
             pipeline_params = {
                 "_batch_size": model._batch_size,  # type:ignore[attr-defined]

@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
             with open(
                 os.path.join(
                     model_blob_path,
-                    cls.MODELE_BLOB_FILE_OR_DIR,
+                    cls.MODEL_BLOB_FILE_OR_DIR,
                     cls.ADDITIONAL_CONFIG_FILE,
                 ),
                 "wb",

@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
                 cloudpickle.dump(pipeline_params, f)
         else:
             with open(
-                os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR),
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)

@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
                 {
                     "task": task,

@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
         cls,
         raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> custom_model.CustomModel:
         import transformers