snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +24 -0
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +73 -31
- snowflake/ml/jobs/decorators.py +3 -0
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +18 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
|
|
1
|
+
import enum
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import pathlib
|
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
|
|
31
32
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
32
33
|
|
33
34
|
|
35
|
+
# An enum class to represent Create Or Alter Model SQL command.
|
36
|
+
class ModelAction(enum.Enum):
|
37
|
+
CREATE = "CREATE"
|
38
|
+
ALTER = "ALTER"
|
39
|
+
|
40
|
+
|
34
41
|
class ServiceInfo(TypedDict):
|
35
42
|
name: str
|
36
43
|
status: str
|
@@ -92,7 +99,7 @@ class ModelOperator:
|
|
92
99
|
and self._model_version_client == __value._model_version_client
|
93
100
|
)
|
94
101
|
|
95
|
-
def
|
102
|
+
def prepare_model_temp_stage_path(
|
96
103
|
self,
|
97
104
|
*,
|
98
105
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -110,17 +117,28 @@ class ModelOperator:
|
|
110
117
|
)
|
111
118
|
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
112
119
|
|
113
|
-
def
|
120
|
+
def get_model_version_stage_path(
|
121
|
+
self,
|
122
|
+
*,
|
123
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
124
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
125
|
+
model_name: sql_identifier.SqlIdentifier,
|
126
|
+
version_name: sql_identifier.SqlIdentifier,
|
127
|
+
) -> str:
|
128
|
+
return (
|
129
|
+
f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
130
|
+
f"/versions/{version_name}/"
|
131
|
+
)
|
132
|
+
|
133
|
+
def get_model_action_from_model_name_and_version(
|
114
134
|
self,
|
115
|
-
composed_model: model_composer.ModelComposer,
|
116
135
|
*,
|
117
136
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
118
137
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
119
138
|
model_name: sql_identifier.SqlIdentifier,
|
120
139
|
version_name: sql_identifier.SqlIdentifier,
|
121
140
|
statement_params: Optional[Dict[str, Any]] = None,
|
122
|
-
) ->
|
123
|
-
stage_path = str(composed_model.stage_path)
|
141
|
+
) -> ModelAction:
|
124
142
|
if self.validate_existence(
|
125
143
|
database_name=database_name,
|
126
144
|
schema_name=schema_name,
|
@@ -140,6 +158,79 @@ class ModelOperator:
|
|
140
158
|
f" version {version_name} already existed."
|
141
159
|
)
|
142
160
|
else:
|
161
|
+
return ModelAction.ALTER
|
162
|
+
else:
|
163
|
+
return ModelAction.CREATE
|
164
|
+
|
165
|
+
def add_or_create_live_version(
|
166
|
+
self,
|
167
|
+
*,
|
168
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
169
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
170
|
+
model_name: sql_identifier.SqlIdentifier,
|
171
|
+
version_name: sql_identifier.SqlIdentifier,
|
172
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
173
|
+
) -> None:
|
174
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
175
|
+
database_name=database_name,
|
176
|
+
schema_name=schema_name,
|
177
|
+
model_name=model_name,
|
178
|
+
version_name=version_name,
|
179
|
+
statement_params=statement_params,
|
180
|
+
)
|
181
|
+
if model_action == ModelAction.CREATE:
|
182
|
+
self._model_version_client.create_live_version(
|
183
|
+
database_name=database_name,
|
184
|
+
schema_name=schema_name,
|
185
|
+
model_name=model_name,
|
186
|
+
version_name=version_name,
|
187
|
+
statement_params=statement_params,
|
188
|
+
)
|
189
|
+
elif model_action == ModelAction.ALTER:
|
190
|
+
self._model_version_client.add_live_version(
|
191
|
+
database_name=database_name,
|
192
|
+
schema_name=schema_name,
|
193
|
+
model_name=model_name,
|
194
|
+
version_name=version_name,
|
195
|
+
statement_params=statement_params,
|
196
|
+
)
|
197
|
+
else:
|
198
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
199
|
+
|
200
|
+
def create_from_stage(
|
201
|
+
self,
|
202
|
+
composed_model: model_composer.ModelComposer,
|
203
|
+
*,
|
204
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
205
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
206
|
+
model_name: sql_identifier.SqlIdentifier,
|
207
|
+
version_name: sql_identifier.SqlIdentifier,
|
208
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
209
|
+
use_live_commit: Optional[bool] = False,
|
210
|
+
) -> None:
|
211
|
+
|
212
|
+
if use_live_commit:
|
213
|
+
# if the model version is live, we can only commit the version
|
214
|
+
self._model_version_client.commit_version(
|
215
|
+
database_name=database_name,
|
216
|
+
schema_name=schema_name,
|
217
|
+
model_name=model_name,
|
218
|
+
version_name=version_name,
|
219
|
+
statement_params=statement_params,
|
220
|
+
)
|
221
|
+
else:
|
222
|
+
stage_path = str(composed_model.stage_path)
|
223
|
+
# if the model version is not live,
|
224
|
+
# find whether the model exists and whether the version exists
|
225
|
+
# and then decide whether to create or alter the model
|
226
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
227
|
+
database_name=database_name,
|
228
|
+
schema_name=schema_name,
|
229
|
+
model_name=model_name,
|
230
|
+
version_name=version_name,
|
231
|
+
statement_params=statement_params,
|
232
|
+
)
|
233
|
+
if model_action == ModelAction.ALTER:
|
143
234
|
self._model_version_client.add_version_from_stage(
|
144
235
|
database_name=database_name,
|
145
236
|
schema_name=schema_name,
|
@@ -148,15 +239,17 @@ class ModelOperator:
|
|
148
239
|
version_name=version_name,
|
149
240
|
statement_params=statement_params,
|
150
241
|
)
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
242
|
+
elif model_action == ModelAction.CREATE:
|
243
|
+
self._model_version_client.create_from_stage(
|
244
|
+
database_name=database_name,
|
245
|
+
schema_name=schema_name,
|
246
|
+
stage_path=stage_path,
|
247
|
+
model_name=model_name,
|
248
|
+
version_name=version_name,
|
249
|
+
statement_params=statement_params,
|
250
|
+
)
|
251
|
+
else:
|
252
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
160
253
|
|
161
254
|
def create_from_model_version(
|
162
255
|
self,
|
@@ -100,7 +100,7 @@ class ServiceOperator:
|
|
100
100
|
max_instances: int,
|
101
101
|
cpu_requests: Optional[str],
|
102
102
|
memory_requests: Optional[str],
|
103
|
-
gpu_requests: Optional[str],
|
103
|
+
gpu_requests: Optional[Union[int, str]],
|
104
104
|
num_workers: Optional[int],
|
105
105
|
max_batch_rows: Optional[int],
|
106
106
|
force_rebuild: bool,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import List, Optional
|
2
|
+
from typing import List, Optional, Union
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
|
|
38
38
|
max_instances: int,
|
39
39
|
cpu: Optional[str],
|
40
40
|
memory: Optional[str],
|
41
|
-
gpu: Optional[str],
|
41
|
+
gpu: Optional[Union[str, int]],
|
42
42
|
num_workers: Optional[int],
|
43
43
|
max_batch_rows: Optional[int],
|
44
44
|
force_rebuild: bool,
|
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
|
|
86
86
|
service_dict["memory"] = memory
|
87
87
|
|
88
88
|
if gpu:
|
89
|
-
|
89
|
+
if isinstance(gpu, int):
|
90
|
+
gpu_str = str(gpu)
|
91
|
+
else:
|
92
|
+
gpu_str = gpu
|
93
|
+
service_dict["gpu"] = gpu_str
|
90
94
|
|
91
95
|
if num_workers:
|
92
96
|
service_dict["num_workers"] = num_workers
|
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
71
71
|
statement_params=statement_params,
|
72
72
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
73
73
|
|
74
|
+
def create_live_version(
|
75
|
+
self,
|
76
|
+
*,
|
77
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
78
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
79
|
+
model_name: sql_identifier.SqlIdentifier,
|
80
|
+
version_name: sql_identifier.SqlIdentifier,
|
81
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
82
|
+
) -> None:
|
83
|
+
sql = (
|
84
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
85
|
+
f" WITH LIVE VERSION {version_name.identifier()}"
|
86
|
+
)
|
87
|
+
query_result_checker.SqlResultValidator(
|
88
|
+
self._session,
|
89
|
+
sql,
|
90
|
+
statement_params=statement_params,
|
91
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
92
|
+
|
93
|
+
def add_live_version(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
97
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
|
+
model_name: sql_identifier.SqlIdentifier,
|
99
|
+
version_name: sql_identifier.SqlIdentifier,
|
100
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
101
|
+
) -> None:
|
102
|
+
sql = (
|
103
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
104
|
+
f" ADD LIVE VERSION {version_name.identifier()}"
|
105
|
+
)
|
106
|
+
query_result_checker.SqlResultValidator(
|
107
|
+
self._session,
|
108
|
+
sql,
|
109
|
+
statement_params=statement_params,
|
110
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
|
+
|
112
|
+
def commit_version(
|
113
|
+
self,
|
114
|
+
*,
|
115
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
116
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
117
|
+
model_name: sql_identifier.SqlIdentifier,
|
118
|
+
version_name: sql_identifier.SqlIdentifier,
|
119
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
120
|
+
) -> None:
|
121
|
+
sql = (
|
122
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
123
|
+
f" COMMIT VERSION {version_name.identifier()}"
|
124
|
+
)
|
125
|
+
|
126
|
+
query_result_checker.SqlResultValidator(
|
127
|
+
self._session,
|
128
|
+
sql,
|
129
|
+
statement_params=statement_params,
|
130
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
131
|
+
|
74
132
|
# TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
|
75
133
|
def add_version_from_stage(
|
76
134
|
self,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import textwrap
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
7
|
from snowflake.ml._internal import platform_capabilities
|
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
|
|
11
11
|
sql_identifier,
|
12
12
|
)
|
13
13
|
from snowflake.ml.model._client.sql import _base
|
14
|
+
from snowflake.ml.model._model_composer.model_method import constants
|
14
15
|
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
15
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
16
17
|
|
@@ -41,7 +42,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
41
42
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
42
43
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
43
44
|
image_repo_name: sql_identifier.SqlIdentifier,
|
44
|
-
gpu: Optional[str],
|
45
|
+
gpu: Optional[Union[str, int]],
|
45
46
|
force_rebuild: bool,
|
46
47
|
external_access_integration: sql_identifier.SqlIdentifier,
|
47
48
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -121,6 +122,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
121
122
|
args_sql_list.append(input_arg_value)
|
122
123
|
args_sql = ", ".join(args_sql_list)
|
123
124
|
|
125
|
+
wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
126
|
+
if wide_input:
|
127
|
+
input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
|
128
|
+
args_sql = f"object_construct_keep_null({input_args_sql})"
|
129
|
+
|
124
130
|
if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
|
125
131
|
fully_qualified_service_name = self.fully_qualified_object_name(
|
126
132
|
actual_database_name, actual_schema_name, service_name
|
@@ -1,8 +1,10 @@
|
|
1
1
|
import pathlib
|
2
2
|
import tempfile
|
3
3
|
import uuid
|
4
|
+
import warnings
|
4
5
|
from types import ModuleType
|
5
|
-
from typing import Any, Dict, List, Optional
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
7
|
+
from urllib import parse
|
6
8
|
|
7
9
|
from absl import logging
|
8
10
|
from packaging import requirements
|
@@ -44,7 +46,13 @@ class ModelComposer:
|
|
44
46
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
47
|
) -> None:
|
46
48
|
self.session = session
|
47
|
-
self.stage_path
|
49
|
+
self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
|
50
|
+
if stage_path.startswith("snow://"):
|
51
|
+
# The stage path is a snowflake internal stage path
|
52
|
+
self.stage_path = parse.urlparse(stage_path)
|
53
|
+
else:
|
54
|
+
# The stage path is a user stage path
|
55
|
+
self.stage_path = pathlib.PurePosixPath(stage_path)
|
48
56
|
|
49
57
|
self._workspace = tempfile.TemporaryDirectory()
|
50
58
|
self._packager_workspace = tempfile.TemporaryDirectory()
|
@@ -70,7 +78,20 @@ class ModelComposer:
|
|
70
78
|
|
71
79
|
@property
|
72
80
|
def model_stage_path(self) -> str:
|
73
|
-
|
81
|
+
if isinstance(self.stage_path, parse.ParseResult):
|
82
|
+
model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
|
83
|
+
new_url = parse.ParseResult(
|
84
|
+
scheme=self.stage_path.scheme,
|
85
|
+
netloc=self.stage_path.netloc,
|
86
|
+
path=str(model_file_path),
|
87
|
+
params=self.stage_path.params,
|
88
|
+
query=self.stage_path.query,
|
89
|
+
fragment=self.stage_path.fragment,
|
90
|
+
)
|
91
|
+
return str(parse.urlunparse(new_url))
|
92
|
+
else:
|
93
|
+
assert isinstance(self.stage_path, pathlib.PurePosixPath)
|
94
|
+
return (self.stage_path / self.model_file_rel_path).as_posix()
|
74
95
|
|
75
96
|
@property
|
76
97
|
def model_local_path(self) -> str:
|
@@ -86,6 +107,7 @@ class ModelComposer:
|
|
86
107
|
metadata: Optional[Dict[str, str]] = None,
|
87
108
|
conda_dependencies: Optional[List[str]] = None,
|
88
109
|
pip_requirements: Optional[List[str]] = None,
|
110
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
89
111
|
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
90
112
|
python_version: Optional[str] = None,
|
91
113
|
user_files: Optional[Dict[str, List[str]]] = None,
|
@@ -94,8 +116,32 @@ class ModelComposer:
|
|
94
116
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
95
117
|
options: Optional[model_types.ModelSaveOption] = None,
|
96
118
|
) -> model_meta.ModelMetadata:
|
119
|
+
# set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
|
120
|
+
conda_dep_dict = env_utils.validate_conda_dependency_string_list(
|
121
|
+
conda_dependencies if conda_dependencies else []
|
122
|
+
)
|
123
|
+
is_warehouse_runnable = (
|
124
|
+
not conda_dep_dict
|
125
|
+
or all(
|
126
|
+
chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
127
|
+
for chan in conda_dep_dict
|
128
|
+
)
|
129
|
+
) and (not pip_requirements)
|
130
|
+
disable_explainability = (
|
131
|
+
target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
132
|
+
) or (not is_warehouse_runnable)
|
133
|
+
|
134
|
+
if disable_explainability and options and options.get("enable_explainability", False):
|
135
|
+
warnings.warn(
|
136
|
+
("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
|
137
|
+
category=UserWarning,
|
138
|
+
stacklevel=2,
|
139
|
+
)
|
140
|
+
|
97
141
|
if not options:
|
98
142
|
options = model_types.BaseModelSaveOption()
|
143
|
+
if disable_explainability:
|
144
|
+
options["enable_explainability"] = False
|
99
145
|
|
100
146
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
101
147
|
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
@@ -120,6 +166,7 @@ class ModelComposer:
|
|
120
166
|
metadata=metadata,
|
121
167
|
conda_dependencies=conda_dependencies,
|
122
168
|
pip_requirements=pip_requirements,
|
169
|
+
artifact_repository_map=artifact_repository_map,
|
123
170
|
python_version=python_version,
|
124
171
|
ext_modules=ext_modules,
|
125
172
|
code_paths=code_paths,
|
@@ -78,6 +78,7 @@ class ModelManifest:
|
|
78
78
|
logger.info("Relaxing version constraints for dependencies in the model.")
|
79
79
|
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
80
80
|
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
81
|
+
logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
|
81
82
|
runtime_dict = runtime_to_use.save(
|
82
83
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
83
84
|
)
|
@@ -124,6 +125,9 @@ class ModelManifest:
|
|
124
125
|
if len(model_meta.env.pip_requirements) > 0:
|
125
126
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
126
127
|
|
128
|
+
if model_meta.env.artifact_repository_map:
|
129
|
+
dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
|
130
|
+
|
127
131
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
128
132
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
129
133
|
runtimes={
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
2
|
import enum
|
3
|
-
from typing import Any, Dict, List, Literal, TypedDict, Union
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
21
|
conda: NotRequired[str]
|
22
22
|
pip: NotRequired[str]
|
23
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
23
24
|
|
24
25
|
|
25
26
|
class ModelRuntimeDict(TypedDict):
|
@@ -98,7 +98,6 @@ class ModelMethod:
|
|
98
98
|
def _get_method_arg_from_feature(
|
99
99
|
feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
|
100
100
|
) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
|
101
|
-
assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
|
102
101
|
try:
|
103
102
|
feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
|
104
103
|
except ValueError as e:
|
@@ -3,7 +3,7 @@ import itertools
|
|
3
3
|
import os
|
4
4
|
import pathlib
|
5
5
|
import warnings
|
6
|
-
from typing import DefaultDict, List, Optional
|
6
|
+
from typing import DefaultDict, Dict, List, Optional
|
7
7
|
|
8
8
|
from packaging import requirements, version
|
9
9
|
|
@@ -36,6 +36,7 @@ class ModelEnv:
|
|
36
36
|
pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
|
37
37
|
self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
|
38
38
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
|
39
|
+
self.artifact_repository_map: Optional[Dict[str, str]] = None
|
39
40
|
self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
|
40
41
|
self._pip_requirements: List[requirements.Requirement] = []
|
41
42
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
@@ -345,6 +346,7 @@ class ModelEnv:
|
|
345
346
|
def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
|
346
347
|
self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
|
347
348
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
|
349
|
+
self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
|
348
350
|
|
349
351
|
self.load_from_conda_file(base_dir / self.conda_env_rel_path)
|
350
352
|
self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
|
@@ -373,6 +375,7 @@ class ModelEnv:
|
|
373
375
|
return {
|
374
376
|
"conda": self.conda_env_rel_path.as_posix(),
|
375
377
|
"pip": self.pip_requirements_rel_path.as_posix(),
|
378
|
+
"artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
|
376
379
|
"python_version": self.python_version,
|
377
380
|
"cuda_version": self.cuda_version,
|
378
381
|
"snowpark_ml_version": self.snowpark_ml_version,
|
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
|
|
30
30
|
model_meta as model_meta_api,
|
31
31
|
model_meta_schema,
|
32
32
|
)
|
33
|
-
from snowflake.ml.model._signatures import
|
34
|
-
builtins_handler,
|
35
|
-
utils as model_signature_utils,
|
36
|
-
)
|
33
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
34
|
from snowflake.ml.model.models import huggingface_pipeline
|
38
35
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
39
36
|
|
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
|
|
66
63
|
return []
|
67
64
|
|
68
65
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
66
|
+
def sanitize_output(data: Any) -> Any:
|
67
|
+
if isinstance(data, np.number):
|
68
|
+
return data.item()
|
69
|
+
if isinstance(data, np.ndarray):
|
70
|
+
return sanitize_output(data.tolist())
|
71
|
+
if isinstance(data, list):
|
72
|
+
return [sanitize_output(x) for x in data]
|
73
|
+
if isinstance(data, dict):
|
74
|
+
return {k: sanitize_output(v) for k, v in data.items()}
|
75
|
+
return data
|
79
76
|
|
80
77
|
|
81
78
|
@final
|
@@ -410,13 +407,17 @@ class HuggingFacePipelineHandler(
|
|
410
407
|
)
|
411
408
|
for conv_data in X.to_dict("records")
|
412
409
|
]
|
413
|
-
elif len(signature.inputs) == 1:
|
414
|
-
input_data = X.to_dict("list")[signature.inputs[0].name]
|
415
410
|
else:
|
416
411
|
if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
|
417
412
|
X["table"] = X["table"].apply(json.loads)
|
418
413
|
|
419
|
-
|
414
|
+
# Most pipelines if it is expecting more than one arguments,
|
415
|
+
# it is expecting a list of dict, where each dict has keys corresponding to the argument.
|
416
|
+
if len(signature.inputs) > 1:
|
417
|
+
input_data = X.to_dict("records")
|
418
|
+
# If it is only expecting one argument, Then it is expecting a list of something.
|
419
|
+
else:
|
420
|
+
input_data = X[signature.inputs[0].name].to_list()
|
420
421
|
temp_res = getattr(raw_model, target_method)(input_data)
|
421
422
|
|
422
423
|
# Some huggingface pipeline will omit the outer list when there is only 1 input.
|
@@ -439,7 +440,6 @@ class HuggingFacePipelineHandler(
|
|
439
440
|
),
|
440
441
|
)
|
441
442
|
and X.shape[0] == 1
|
442
|
-
and isinstance(temp_res[0], dict)
|
443
443
|
)
|
444
444
|
):
|
445
445
|
temp_res = [temp_res]
|
@@ -453,14 +453,18 @@ class HuggingFacePipelineHandler(
|
|
453
453
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
454
454
|
|
455
455
|
# To concat those who outputs a list with one input.
|
456
|
-
if
|
457
|
-
|
458
|
-
|
456
|
+
if isinstance(temp_res[0], list):
|
457
|
+
if isinstance(temp_res[0][0], dict):
|
458
|
+
res = pd.DataFrame({0: temp_res})
|
459
|
+
else:
|
460
|
+
res = pd.DataFrame(temp_res)
|
461
|
+
else:
|
459
462
|
res = pd.DataFrame(temp_res)
|
460
|
-
|
461
|
-
|
463
|
+
|
464
|
+
if hasattr(res, "map"):
|
465
|
+
res = res.map(sanitize_output)
|
462
466
|
else:
|
463
|
-
|
467
|
+
res = res.applymap(sanitize_output)
|
464
468
|
|
465
469
|
return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
|
466
470
|
|
@@ -191,11 +191,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
191
191
|
signature: model_signature.ModelSignature,
|
192
192
|
target_method: str,
|
193
193
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
194
|
-
dtype_map = {
|
195
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True)
|
196
|
-
for spec in signature.inputs
|
197
|
-
if isinstance(spec, model_signature.FeatureSpec)
|
198
|
-
}
|
194
|
+
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
199
195
|
|
200
196
|
@custom_model.inference_api
|
201
197
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|