snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
import inspect
|
4
3
|
from typing import Callable, Dict, List, Optional, TypeVar
|
5
4
|
|
6
5
|
from typing_extensions import ParamSpec
|
@@ -8,7 +7,7 @@ from typing_extensions import ParamSpec
|
|
8
7
|
from snowflake import snowpark
|
9
8
|
from snowflake.ml._internal import telemetry
|
10
9
|
from snowflake.ml.jobs import job as jb, manager as jm
|
11
|
-
from snowflake.ml.jobs._utils import
|
10
|
+
from snowflake.ml.jobs._utils import constants
|
12
11
|
|
13
12
|
_PROJECT = "MLJob"
|
14
13
|
|
@@ -26,6 +25,7 @@ def remote(
|
|
26
25
|
query_warehouse: Optional[str] = None,
|
27
26
|
env_vars: Optional[Dict[str, str]] = None,
|
28
27
|
session: Optional[snowpark.Session] = None,
|
28
|
+
num_instances: Optional[int] = None,
|
29
29
|
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
|
30
30
|
"""
|
31
31
|
Submit a job to the compute pool.
|
@@ -38,6 +38,7 @@ def remote(
|
|
38
38
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
39
39
|
env_vars: Environment variables to set in container
|
40
40
|
session: The Snowpark session to use. If none specified, uses active session.
|
41
|
+
num_instances: The number of nodes in the job. If none specified, create a single node job.
|
41
42
|
|
42
43
|
Returns:
|
43
44
|
Decorator that dispatches invocations of the decorated function as remote jobs.
|
@@ -50,31 +51,12 @@ def remote(
|
|
50
51
|
wrapped_func = copy.copy(func)
|
51
52
|
wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
|
52
53
|
|
53
|
-
# Validate function arguments based on signature
|
54
|
-
signature = inspect.signature(func)
|
55
|
-
pos_arg_names = []
|
56
|
-
for name, param in signature.parameters.items():
|
57
|
-
param_type = payload_utils.get_parameter_type(param)
|
58
|
-
if param_type is not None:
|
59
|
-
payload_utils.validate_parameter_type(param_type, name)
|
60
|
-
if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
61
|
-
pos_arg_names.append(name)
|
62
|
-
|
63
54
|
@functools.wraps(func)
|
64
55
|
def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
|
65
|
-
|
66
|
-
|
67
|
-
arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
|
68
|
-
payload_utils.validate_parameter_type(type(arg), arg_name)
|
69
|
-
|
70
|
-
# Validate keyword args
|
71
|
-
for k, v in kwargs.items():
|
72
|
-
payload_utils.validate_parameter_type(type(v), k)
|
73
|
-
|
74
|
-
arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
|
56
|
+
payload = functools.partial(func, *args, **kwargs)
|
57
|
+
setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
|
75
58
|
job = jm._submit_job(
|
76
|
-
source=
|
77
|
-
args=arg_list,
|
59
|
+
source=payload,
|
78
60
|
stage_name=stage_name,
|
79
61
|
compute_pool=compute_pool,
|
80
62
|
pip_requirements=pip_requirements,
|
@@ -82,8 +64,9 @@ def remote(
|
|
82
64
|
query_warehouse=query_warehouse,
|
83
65
|
env_vars=env_vars,
|
84
66
|
session=session,
|
67
|
+
num_instances=num_instances,
|
85
68
|
)
|
86
|
-
assert isinstance(job, jb.MLJob)
|
69
|
+
assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
|
87
70
|
return job
|
88
71
|
|
89
72
|
return wrapper
|
snowflake/ml/jobs/job.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, cast
|
|
4
4
|
from snowflake import snowpark
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
6
|
from snowflake.ml.jobs._utils import constants, types
|
7
|
-
from snowflake.snowpark
|
7
|
+
from snowflake.snowpark import context as sp_context
|
8
8
|
|
9
9
|
_PROJECT = "MLJob"
|
10
10
|
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
@@ -13,7 +13,7 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
|
13
13
|
class MLJob:
|
14
14
|
def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
|
15
15
|
self._id = id
|
16
|
-
self._session = session or get_active_session()
|
16
|
+
self._session = session or sp_context.get_active_session()
|
17
17
|
self._status: types.JOB_STATUS = "PENDING"
|
18
18
|
|
19
19
|
@property
|
@@ -79,7 +79,7 @@ class MLJob:
|
|
79
79
|
return self.status
|
80
80
|
|
81
81
|
|
82
|
-
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
82
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
83
83
|
def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
|
84
84
|
"""Retrieve job execution status."""
|
85
85
|
# TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
|
@@ -90,7 +90,7 @@ def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
|
|
90
90
|
return cast(types.JOB_STATUS, row["status"])
|
91
91
|
|
92
92
|
|
93
|
-
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
93
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
|
94
94
|
def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
|
95
95
|
"""
|
96
96
|
Retrieve the job's execution logs.
|
snowflake/ml/jobs/manager.py
CHANGED
@@ -213,6 +213,7 @@ def _submit_job(
|
|
213
213
|
query_warehouse: Optional[str] = None,
|
214
214
|
spec_overrides: Optional[Dict[str, Any]] = None,
|
215
215
|
session: Optional[snowpark.Session] = None,
|
216
|
+
num_instances: Optional[int] = None,
|
216
217
|
) -> jb.MLJob:
|
217
218
|
"""
|
218
219
|
Submit a job to the compute pool.
|
@@ -229,6 +230,7 @@ def _submit_job(
|
|
229
230
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
230
231
|
spec_overrides: Custom service specification overrides to apply.
|
231
232
|
session: The Snowpark session to use. If none specified, uses active session.
|
233
|
+
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
232
234
|
|
233
235
|
Returns:
|
234
236
|
An object representing the submitted job.
|
@@ -254,6 +256,7 @@ def _submit_job(
|
|
254
256
|
compute_pool=compute_pool,
|
255
257
|
payload=uploaded_payload,
|
256
258
|
args=args,
|
259
|
+
num_instances=num_instances,
|
257
260
|
)
|
258
261
|
spec_overrides = spec_utils.generate_spec_overrides(
|
259
262
|
environment_vars=env_vars,
|
@@ -281,6 +284,8 @@ def _submit_job(
|
|
281
284
|
query_warehouse = query_warehouse or session.get_current_warehouse()
|
282
285
|
if query_warehouse:
|
283
286
|
query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
|
287
|
+
if num_instances:
|
288
|
+
query.append(f"REPLICAS = {num_instances}")
|
284
289
|
|
285
290
|
# Submit job
|
286
291
|
query_text = "\n".join(line for line in query if line)
|
@@ -746,7 +746,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
746
746
|
max_instances: int = 1,
|
747
747
|
cpu_requests: Optional[str] = None,
|
748
748
|
memory_requests: Optional[str] = None,
|
749
|
-
gpu_requests: Optional[str] = None,
|
749
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
750
750
|
num_workers: Optional[int] = None,
|
751
751
|
max_batch_rows: Optional[int] = None,
|
752
752
|
force_rebuild: bool = False,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import enum
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import pathlib
|
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
|
|
31
32
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
32
33
|
|
33
34
|
|
35
|
+
# An enum class to represent Create Or Alter Model SQL command.
|
36
|
+
class ModelAction(enum.Enum):
|
37
|
+
CREATE = "CREATE"
|
38
|
+
ALTER = "ALTER"
|
39
|
+
|
40
|
+
|
34
41
|
class ServiceInfo(TypedDict):
|
35
42
|
name: str
|
36
43
|
status: str
|
@@ -92,7 +99,7 @@ class ModelOperator:
|
|
92
99
|
and self._model_version_client == __value._model_version_client
|
93
100
|
)
|
94
101
|
|
95
|
-
def
|
102
|
+
def prepare_model_temp_stage_path(
|
96
103
|
self,
|
97
104
|
*,
|
98
105
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -110,17 +117,28 @@ class ModelOperator:
|
|
110
117
|
)
|
111
118
|
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
112
119
|
|
113
|
-
def
|
120
|
+
def get_model_version_stage_path(
|
121
|
+
self,
|
122
|
+
*,
|
123
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
124
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
125
|
+
model_name: sql_identifier.SqlIdentifier,
|
126
|
+
version_name: sql_identifier.SqlIdentifier,
|
127
|
+
) -> str:
|
128
|
+
return (
|
129
|
+
f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
130
|
+
f"/versions/{version_name}/"
|
131
|
+
)
|
132
|
+
|
133
|
+
def get_model_action_from_model_name_and_version(
|
114
134
|
self,
|
115
|
-
composed_model: model_composer.ModelComposer,
|
116
135
|
*,
|
117
136
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
118
137
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
119
138
|
model_name: sql_identifier.SqlIdentifier,
|
120
139
|
version_name: sql_identifier.SqlIdentifier,
|
121
140
|
statement_params: Optional[Dict[str, Any]] = None,
|
122
|
-
) ->
|
123
|
-
stage_path = str(composed_model.stage_path)
|
141
|
+
) -> ModelAction:
|
124
142
|
if self.validate_existence(
|
125
143
|
database_name=database_name,
|
126
144
|
schema_name=schema_name,
|
@@ -140,6 +158,79 @@ class ModelOperator:
|
|
140
158
|
f" version {version_name} already existed."
|
141
159
|
)
|
142
160
|
else:
|
161
|
+
return ModelAction.ALTER
|
162
|
+
else:
|
163
|
+
return ModelAction.CREATE
|
164
|
+
|
165
|
+
def add_or_create_live_version(
|
166
|
+
self,
|
167
|
+
*,
|
168
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
169
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
170
|
+
model_name: sql_identifier.SqlIdentifier,
|
171
|
+
version_name: sql_identifier.SqlIdentifier,
|
172
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
173
|
+
) -> None:
|
174
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
175
|
+
database_name=database_name,
|
176
|
+
schema_name=schema_name,
|
177
|
+
model_name=model_name,
|
178
|
+
version_name=version_name,
|
179
|
+
statement_params=statement_params,
|
180
|
+
)
|
181
|
+
if model_action == ModelAction.CREATE:
|
182
|
+
self._model_version_client.create_live_version(
|
183
|
+
database_name=database_name,
|
184
|
+
schema_name=schema_name,
|
185
|
+
model_name=model_name,
|
186
|
+
version_name=version_name,
|
187
|
+
statement_params=statement_params,
|
188
|
+
)
|
189
|
+
elif model_action == ModelAction.ALTER:
|
190
|
+
self._model_version_client.add_live_version(
|
191
|
+
database_name=database_name,
|
192
|
+
schema_name=schema_name,
|
193
|
+
model_name=model_name,
|
194
|
+
version_name=version_name,
|
195
|
+
statement_params=statement_params,
|
196
|
+
)
|
197
|
+
else:
|
198
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
199
|
+
|
200
|
+
def create_from_stage(
|
201
|
+
self,
|
202
|
+
composed_model: model_composer.ModelComposer,
|
203
|
+
*,
|
204
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
205
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
206
|
+
model_name: sql_identifier.SqlIdentifier,
|
207
|
+
version_name: sql_identifier.SqlIdentifier,
|
208
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
209
|
+
use_live_commit: Optional[bool] = False,
|
210
|
+
) -> None:
|
211
|
+
|
212
|
+
if use_live_commit:
|
213
|
+
# if the model version is live, we can only commit the version
|
214
|
+
self._model_version_client.commit_version(
|
215
|
+
database_name=database_name,
|
216
|
+
schema_name=schema_name,
|
217
|
+
model_name=model_name,
|
218
|
+
version_name=version_name,
|
219
|
+
statement_params=statement_params,
|
220
|
+
)
|
221
|
+
else:
|
222
|
+
stage_path = str(composed_model.stage_path)
|
223
|
+
# if the model version is not live,
|
224
|
+
# find whether the model exists and whether the version exists
|
225
|
+
# and then decide whether to create or alter the model
|
226
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
227
|
+
database_name=database_name,
|
228
|
+
schema_name=schema_name,
|
229
|
+
model_name=model_name,
|
230
|
+
version_name=version_name,
|
231
|
+
statement_params=statement_params,
|
232
|
+
)
|
233
|
+
if model_action == ModelAction.ALTER:
|
143
234
|
self._model_version_client.add_version_from_stage(
|
144
235
|
database_name=database_name,
|
145
236
|
schema_name=schema_name,
|
@@ -148,15 +239,17 @@ class ModelOperator:
|
|
148
239
|
version_name=version_name,
|
149
240
|
statement_params=statement_params,
|
150
241
|
)
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
242
|
+
elif model_action == ModelAction.CREATE:
|
243
|
+
self._model_version_client.create_from_stage(
|
244
|
+
database_name=database_name,
|
245
|
+
schema_name=schema_name,
|
246
|
+
stage_path=stage_path,
|
247
|
+
model_name=model_name,
|
248
|
+
version_name=version_name,
|
249
|
+
statement_params=statement_params,
|
250
|
+
)
|
251
|
+
else:
|
252
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
160
253
|
|
161
254
|
def create_from_model_version(
|
162
255
|
self,
|
@@ -100,7 +100,7 @@ class ServiceOperator:
|
|
100
100
|
max_instances: int,
|
101
101
|
cpu_requests: Optional[str],
|
102
102
|
memory_requests: Optional[str],
|
103
|
-
gpu_requests: Optional[str],
|
103
|
+
gpu_requests: Optional[Union[int, str]],
|
104
104
|
num_workers: Optional[int],
|
105
105
|
max_batch_rows: Optional[int],
|
106
106
|
force_rebuild: bool,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import List, Optional
|
2
|
+
from typing import List, Optional, Union
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
|
|
38
38
|
max_instances: int,
|
39
39
|
cpu: Optional[str],
|
40
40
|
memory: Optional[str],
|
41
|
-
gpu: Optional[str],
|
41
|
+
gpu: Optional[Union[str, int]],
|
42
42
|
num_workers: Optional[int],
|
43
43
|
max_batch_rows: Optional[int],
|
44
44
|
force_rebuild: bool,
|
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
|
|
86
86
|
service_dict["memory"] = memory
|
87
87
|
|
88
88
|
if gpu:
|
89
|
-
|
89
|
+
if isinstance(gpu, int):
|
90
|
+
gpu_str = str(gpu)
|
91
|
+
else:
|
92
|
+
gpu_str = gpu
|
93
|
+
service_dict["gpu"] = gpu_str
|
90
94
|
|
91
95
|
if num_workers:
|
92
96
|
service_dict["num_workers"] = num_workers
|
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
71
71
|
statement_params=statement_params,
|
72
72
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
73
73
|
|
74
|
+
def create_live_version(
|
75
|
+
self,
|
76
|
+
*,
|
77
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
78
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
79
|
+
model_name: sql_identifier.SqlIdentifier,
|
80
|
+
version_name: sql_identifier.SqlIdentifier,
|
81
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
82
|
+
) -> None:
|
83
|
+
sql = (
|
84
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
85
|
+
f" WITH LIVE VERSION {version_name.identifier()}"
|
86
|
+
)
|
87
|
+
query_result_checker.SqlResultValidator(
|
88
|
+
self._session,
|
89
|
+
sql,
|
90
|
+
statement_params=statement_params,
|
91
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
92
|
+
|
93
|
+
def add_live_version(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
97
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
|
+
model_name: sql_identifier.SqlIdentifier,
|
99
|
+
version_name: sql_identifier.SqlIdentifier,
|
100
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
101
|
+
) -> None:
|
102
|
+
sql = (
|
103
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
104
|
+
f" ADD LIVE VERSION {version_name.identifier()}"
|
105
|
+
)
|
106
|
+
query_result_checker.SqlResultValidator(
|
107
|
+
self._session,
|
108
|
+
sql,
|
109
|
+
statement_params=statement_params,
|
110
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
|
+
|
112
|
+
def commit_version(
|
113
|
+
self,
|
114
|
+
*,
|
115
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
116
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
117
|
+
model_name: sql_identifier.SqlIdentifier,
|
118
|
+
version_name: sql_identifier.SqlIdentifier,
|
119
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
120
|
+
) -> None:
|
121
|
+
sql = (
|
122
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
123
|
+
f" COMMIT VERSION {version_name.identifier()}"
|
124
|
+
)
|
125
|
+
|
126
|
+
query_result_checker.SqlResultValidator(
|
127
|
+
self._session,
|
128
|
+
sql,
|
129
|
+
statement_params=statement_params,
|
130
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
131
|
+
|
74
132
|
# TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
|
75
133
|
def add_version_from_stage(
|
76
134
|
self,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import textwrap
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
7
|
from snowflake.ml._internal import platform_capabilities
|
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
|
|
11
11
|
sql_identifier,
|
12
12
|
)
|
13
13
|
from snowflake.ml.model._client.sql import _base
|
14
|
+
from snowflake.ml.model._model_composer.model_method import constants
|
14
15
|
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
15
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
16
17
|
|
@@ -41,7 +42,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
41
42
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
42
43
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
43
44
|
image_repo_name: sql_identifier.SqlIdentifier,
|
44
|
-
gpu: Optional[str],
|
45
|
+
gpu: Optional[Union[str, int]],
|
45
46
|
force_rebuild: bool,
|
46
47
|
external_access_integration: sql_identifier.SqlIdentifier,
|
47
48
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -121,6 +122,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
121
122
|
args_sql_list.append(input_arg_value)
|
122
123
|
args_sql = ", ".join(args_sql_list)
|
123
124
|
|
125
|
+
wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
126
|
+
if wide_input:
|
127
|
+
input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
|
128
|
+
args_sql = f"object_construct_keep_null({input_args_sql})"
|
129
|
+
|
124
130
|
if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
|
125
131
|
fully_qualified_service_name = self.fully_qualified_object_name(
|
126
132
|
actual_database_name, actual_schema_name, service_name
|
@@ -1,8 +1,10 @@
|
|
1
1
|
import pathlib
|
2
2
|
import tempfile
|
3
3
|
import uuid
|
4
|
+
import warnings
|
4
5
|
from types import ModuleType
|
5
|
-
from typing import Any, Dict, List, Optional
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
7
|
+
from urllib import parse
|
6
8
|
|
7
9
|
from absl import logging
|
8
10
|
from packaging import requirements
|
@@ -44,7 +46,13 @@ class ModelComposer:
|
|
44
46
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
47
|
) -> None:
|
46
48
|
self.session = session
|
47
|
-
self.stage_path
|
49
|
+
self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
|
50
|
+
if stage_path.startswith("snow://"):
|
51
|
+
# The stage path is a snowflake internal stage path
|
52
|
+
self.stage_path = parse.urlparse(stage_path)
|
53
|
+
else:
|
54
|
+
# The stage path is a user stage path
|
55
|
+
self.stage_path = pathlib.PurePosixPath(stage_path)
|
48
56
|
|
49
57
|
self._workspace = tempfile.TemporaryDirectory()
|
50
58
|
self._packager_workspace = tempfile.TemporaryDirectory()
|
@@ -70,7 +78,20 @@ class ModelComposer:
|
|
70
78
|
|
71
79
|
@property
|
72
80
|
def model_stage_path(self) -> str:
|
73
|
-
|
81
|
+
if isinstance(self.stage_path, parse.ParseResult):
|
82
|
+
model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
|
83
|
+
new_url = parse.ParseResult(
|
84
|
+
scheme=self.stage_path.scheme,
|
85
|
+
netloc=self.stage_path.netloc,
|
86
|
+
path=str(model_file_path),
|
87
|
+
params=self.stage_path.params,
|
88
|
+
query=self.stage_path.query,
|
89
|
+
fragment=self.stage_path.fragment,
|
90
|
+
)
|
91
|
+
return str(parse.urlunparse(new_url))
|
92
|
+
else:
|
93
|
+
assert isinstance(self.stage_path, pathlib.PurePosixPath)
|
94
|
+
return (self.stage_path / self.model_file_rel_path).as_posix()
|
74
95
|
|
75
96
|
@property
|
76
97
|
def model_local_path(self) -> str:
|
@@ -86,6 +107,7 @@ class ModelComposer:
|
|
86
107
|
metadata: Optional[Dict[str, str]] = None,
|
87
108
|
conda_dependencies: Optional[List[str]] = None,
|
88
109
|
pip_requirements: Optional[List[str]] = None,
|
110
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
89
111
|
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
90
112
|
python_version: Optional[str] = None,
|
91
113
|
user_files: Optional[Dict[str, List[str]]] = None,
|
@@ -94,8 +116,32 @@ class ModelComposer:
|
|
94
116
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
95
117
|
options: Optional[model_types.ModelSaveOption] = None,
|
96
118
|
) -> model_meta.ModelMetadata:
|
119
|
+
# set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
|
120
|
+
conda_dep_dict = env_utils.validate_conda_dependency_string_list(
|
121
|
+
conda_dependencies if conda_dependencies else []
|
122
|
+
)
|
123
|
+
is_warehouse_runnable = (
|
124
|
+
not conda_dep_dict
|
125
|
+
or all(
|
126
|
+
chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
127
|
+
for chan in conda_dep_dict
|
128
|
+
)
|
129
|
+
) and (not pip_requirements)
|
130
|
+
disable_explainability = (
|
131
|
+
target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
132
|
+
) or (not is_warehouse_runnable)
|
133
|
+
|
134
|
+
if disable_explainability and options and options.get("enable_explainability", False):
|
135
|
+
warnings.warn(
|
136
|
+
("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
|
137
|
+
category=UserWarning,
|
138
|
+
stacklevel=2,
|
139
|
+
)
|
140
|
+
|
97
141
|
if not options:
|
98
142
|
options = model_types.BaseModelSaveOption()
|
143
|
+
if disable_explainability:
|
144
|
+
options["enable_explainability"] = False
|
99
145
|
|
100
146
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
101
147
|
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
@@ -120,6 +166,7 @@ class ModelComposer:
|
|
120
166
|
metadata=metadata,
|
121
167
|
conda_dependencies=conda_dependencies,
|
122
168
|
pip_requirements=pip_requirements,
|
169
|
+
artifact_repository_map=artifact_repository_map,
|
123
170
|
python_version=python_version,
|
124
171
|
ext_modules=ext_modules,
|
125
172
|
code_paths=code_paths,
|
@@ -78,6 +78,7 @@ class ModelManifest:
|
|
78
78
|
logger.info("Relaxing version constraints for dependencies in the model.")
|
79
79
|
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
80
80
|
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
81
|
+
logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
|
81
82
|
runtime_dict = runtime_to_use.save(
|
82
83
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
83
84
|
)
|
@@ -124,6 +125,9 @@ class ModelManifest:
|
|
124
125
|
if len(model_meta.env.pip_requirements) > 0:
|
125
126
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
126
127
|
|
128
|
+
if model_meta.env.artifact_repository_map:
|
129
|
+
dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
|
130
|
+
|
127
131
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
128
132
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
129
133
|
runtimes={
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
2
|
import enum
|
3
|
-
from typing import Any, Dict, List, Literal, TypedDict, Union
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
21
|
conda: NotRequired[str]
|
22
22
|
pip: NotRequired[str]
|
23
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
23
24
|
|
24
25
|
|
25
26
|
class ModelRuntimeDict(TypedDict):
|
@@ -98,7 +98,6 @@ class ModelMethod:
|
|
98
98
|
def _get_method_arg_from_feature(
|
99
99
|
feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
|
100
100
|
) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
|
101
|
-
assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
|
102
101
|
try:
|
103
102
|
feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
|
104
103
|
except ValueError as e:
|