snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +25 -1
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +95 -31
- snowflake/ml/jobs/decorators.py +7 -0
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +113 -17
- snowflake/ml/model/_client/ops/service_ops.py +16 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +10 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +59 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -106,6 +106,8 @@ def submit_file(
|
|
106
106
|
external_access_integrations: Optional[List[str]] = None,
|
107
107
|
query_warehouse: Optional[str] = None,
|
108
108
|
spec_overrides: Optional[Dict[str, Any]] = None,
|
109
|
+
num_instances: Optional[int] = None,
|
110
|
+
enable_metrics: bool = False,
|
109
111
|
session: Optional[snowpark.Session] = None,
|
110
112
|
) -> jb.MLJob:
|
111
113
|
"""
|
@@ -121,6 +123,8 @@ def submit_file(
|
|
121
123
|
external_access_integrations: A list of external access integrations.
|
122
124
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
123
125
|
spec_overrides: Custom service specification overrides to apply.
|
126
|
+
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
127
|
+
enable_metrics: Whether to enable metrics publishing for the job.
|
124
128
|
session: The Snowpark session to use. If none specified, uses active session.
|
125
129
|
|
126
130
|
Returns:
|
@@ -136,6 +140,8 @@ def submit_file(
|
|
136
140
|
external_access_integrations=external_access_integrations,
|
137
141
|
query_warehouse=query_warehouse,
|
138
142
|
spec_overrides=spec_overrides,
|
143
|
+
num_instances=num_instances,
|
144
|
+
enable_metrics=enable_metrics,
|
139
145
|
session=session,
|
140
146
|
)
|
141
147
|
|
@@ -154,6 +160,8 @@ def submit_directory(
|
|
154
160
|
external_access_integrations: Optional[List[str]] = None,
|
155
161
|
query_warehouse: Optional[str] = None,
|
156
162
|
spec_overrides: Optional[Dict[str, Any]] = None,
|
163
|
+
num_instances: Optional[int] = None,
|
164
|
+
enable_metrics: bool = False,
|
157
165
|
session: Optional[snowpark.Session] = None,
|
158
166
|
) -> jb.MLJob:
|
159
167
|
"""
|
@@ -170,6 +178,8 @@ def submit_directory(
|
|
170
178
|
external_access_integrations: A list of external access integrations.
|
171
179
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
172
180
|
spec_overrides: Custom service specification overrides to apply.
|
181
|
+
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
182
|
+
enable_metrics: Whether to enable metrics publishing for the job.
|
173
183
|
session: The Snowpark session to use. If none specified, uses active session.
|
174
184
|
|
175
185
|
Returns:
|
@@ -186,6 +196,8 @@ def submit_directory(
|
|
186
196
|
external_access_integrations=external_access_integrations,
|
187
197
|
query_warehouse=query_warehouse,
|
188
198
|
spec_overrides=spec_overrides,
|
199
|
+
num_instances=num_instances,
|
200
|
+
enable_metrics=enable_metrics,
|
189
201
|
session=session,
|
190
202
|
)
|
191
203
|
|
@@ -212,6 +224,8 @@ def _submit_job(
|
|
212
224
|
external_access_integrations: Optional[List[str]] = None,
|
213
225
|
query_warehouse: Optional[str] = None,
|
214
226
|
spec_overrides: Optional[Dict[str, Any]] = None,
|
227
|
+
num_instances: Optional[int] = None,
|
228
|
+
enable_metrics: bool = False,
|
215
229
|
session: Optional[snowpark.Session] = None,
|
216
230
|
) -> jb.MLJob:
|
217
231
|
"""
|
@@ -228,6 +242,8 @@ def _submit_job(
|
|
228
242
|
external_access_integrations: A list of external access integrations.
|
229
243
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
230
244
|
spec_overrides: Custom service specification overrides to apply.
|
245
|
+
num_instances: The number of instances to use for the job. If none specified, single node job is created.
|
246
|
+
enable_metrics: Whether to enable metrics publishing for the job.
|
231
247
|
session: The Snowpark session to use. If none specified, uses active session.
|
232
248
|
|
233
249
|
Returns:
|
@@ -254,6 +270,8 @@ def _submit_job(
|
|
254
270
|
compute_pool=compute_pool,
|
255
271
|
payload=uploaded_payload,
|
256
272
|
args=args,
|
273
|
+
num_instances=num_instances,
|
274
|
+
enable_metrics=enable_metrics,
|
257
275
|
)
|
258
276
|
spec_overrides = spec_utils.generate_spec_overrides(
|
259
277
|
environment_vars=env_vars,
|
@@ -281,6 +299,8 @@ def _submit_job(
|
|
281
299
|
query_warehouse = query_warehouse or session.get_current_warehouse()
|
282
300
|
if query_warehouse:
|
283
301
|
query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
|
302
|
+
if num_instances:
|
303
|
+
query.append(f"REPLICAS = {num_instances}")
|
284
304
|
|
285
305
|
# Submit job
|
286
306
|
query_text = "\n".join(line for line in query if line)
|
@@ -746,7 +746,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
746
746
|
max_instances: int = 1,
|
747
747
|
cpu_requests: Optional[str] = None,
|
748
748
|
memory_requests: Optional[str] = None,
|
749
|
-
gpu_requests: Optional[str] = None,
|
749
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
750
750
|
num_workers: Optional[int] = None,
|
751
751
|
max_batch_rows: Optional[int] = None,
|
752
752
|
force_rebuild: bool = False,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import enum
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import pathlib
|
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
|
|
31
32
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
32
33
|
|
33
34
|
|
35
|
+
# An enum class to represent Create Or Alter Model SQL command.
|
36
|
+
class ModelAction(enum.Enum):
|
37
|
+
CREATE = "CREATE"
|
38
|
+
ALTER = "ALTER"
|
39
|
+
|
40
|
+
|
34
41
|
class ServiceInfo(TypedDict):
|
35
42
|
name: str
|
36
43
|
status: str
|
@@ -92,7 +99,7 @@ class ModelOperator:
|
|
92
99
|
and self._model_version_client == __value._model_version_client
|
93
100
|
)
|
94
101
|
|
95
|
-
def
|
102
|
+
def prepare_model_temp_stage_path(
|
96
103
|
self,
|
97
104
|
*,
|
98
105
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -110,17 +117,28 @@ class ModelOperator:
|
|
110
117
|
)
|
111
118
|
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
112
119
|
|
113
|
-
def
|
120
|
+
def get_model_version_stage_path(
|
121
|
+
self,
|
122
|
+
*,
|
123
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
124
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
125
|
+
model_name: sql_identifier.SqlIdentifier,
|
126
|
+
version_name: sql_identifier.SqlIdentifier,
|
127
|
+
) -> str:
|
128
|
+
return (
|
129
|
+
f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
130
|
+
f"/versions/{version_name}/"
|
131
|
+
)
|
132
|
+
|
133
|
+
def get_model_action_from_model_name_and_version(
|
114
134
|
self,
|
115
|
-
composed_model: model_composer.ModelComposer,
|
116
135
|
*,
|
117
136
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
118
137
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
119
138
|
model_name: sql_identifier.SqlIdentifier,
|
120
139
|
version_name: sql_identifier.SqlIdentifier,
|
121
140
|
statement_params: Optional[Dict[str, Any]] = None,
|
122
|
-
) ->
|
123
|
-
stage_path = str(composed_model.stage_path)
|
141
|
+
) -> ModelAction:
|
124
142
|
if self.validate_existence(
|
125
143
|
database_name=database_name,
|
126
144
|
schema_name=schema_name,
|
@@ -140,6 +158,79 @@ class ModelOperator:
|
|
140
158
|
f" version {version_name} already existed."
|
141
159
|
)
|
142
160
|
else:
|
161
|
+
return ModelAction.ALTER
|
162
|
+
else:
|
163
|
+
return ModelAction.CREATE
|
164
|
+
|
165
|
+
def add_or_create_live_version(
|
166
|
+
self,
|
167
|
+
*,
|
168
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
169
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
170
|
+
model_name: sql_identifier.SqlIdentifier,
|
171
|
+
version_name: sql_identifier.SqlIdentifier,
|
172
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
173
|
+
) -> None:
|
174
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
175
|
+
database_name=database_name,
|
176
|
+
schema_name=schema_name,
|
177
|
+
model_name=model_name,
|
178
|
+
version_name=version_name,
|
179
|
+
statement_params=statement_params,
|
180
|
+
)
|
181
|
+
if model_action == ModelAction.CREATE:
|
182
|
+
self._model_version_client.create_live_version(
|
183
|
+
database_name=database_name,
|
184
|
+
schema_name=schema_name,
|
185
|
+
model_name=model_name,
|
186
|
+
version_name=version_name,
|
187
|
+
statement_params=statement_params,
|
188
|
+
)
|
189
|
+
elif model_action == ModelAction.ALTER:
|
190
|
+
self._model_version_client.add_live_version(
|
191
|
+
database_name=database_name,
|
192
|
+
schema_name=schema_name,
|
193
|
+
model_name=model_name,
|
194
|
+
version_name=version_name,
|
195
|
+
statement_params=statement_params,
|
196
|
+
)
|
197
|
+
else:
|
198
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
199
|
+
|
200
|
+
def create_from_stage(
|
201
|
+
self,
|
202
|
+
composed_model: model_composer.ModelComposer,
|
203
|
+
*,
|
204
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
205
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
206
|
+
model_name: sql_identifier.SqlIdentifier,
|
207
|
+
version_name: sql_identifier.SqlIdentifier,
|
208
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
209
|
+
use_live_commit: Optional[bool] = False,
|
210
|
+
) -> None:
|
211
|
+
|
212
|
+
if use_live_commit:
|
213
|
+
# if the model version is live, we can only commit the version
|
214
|
+
self._model_version_client.commit_version(
|
215
|
+
database_name=database_name,
|
216
|
+
schema_name=schema_name,
|
217
|
+
model_name=model_name,
|
218
|
+
version_name=version_name,
|
219
|
+
statement_params=statement_params,
|
220
|
+
)
|
221
|
+
else:
|
222
|
+
stage_path = str(composed_model.stage_path)
|
223
|
+
# if the model version is not live,
|
224
|
+
# find whether the model exists and whether the version exists
|
225
|
+
# and then decide whether to create or alter the model
|
226
|
+
model_action = self.get_model_action_from_model_name_and_version(
|
227
|
+
database_name=database_name,
|
228
|
+
schema_name=schema_name,
|
229
|
+
model_name=model_name,
|
230
|
+
version_name=version_name,
|
231
|
+
statement_params=statement_params,
|
232
|
+
)
|
233
|
+
if model_action == ModelAction.ALTER:
|
143
234
|
self._model_version_client.add_version_from_stage(
|
144
235
|
database_name=database_name,
|
145
236
|
schema_name=schema_name,
|
@@ -148,15 +239,17 @@ class ModelOperator:
|
|
148
239
|
version_name=version_name,
|
149
240
|
statement_params=statement_params,
|
150
241
|
)
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
242
|
+
elif model_action == ModelAction.CREATE:
|
243
|
+
self._model_version_client.create_from_stage(
|
244
|
+
database_name=database_name,
|
245
|
+
schema_name=schema_name,
|
246
|
+
stage_path=stage_path,
|
247
|
+
model_name=model_name,
|
248
|
+
version_name=version_name,
|
249
|
+
statement_params=statement_params,
|
250
|
+
)
|
251
|
+
else:
|
252
|
+
raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
|
160
253
|
|
161
254
|
def create_from_model_version(
|
162
255
|
self,
|
@@ -696,14 +789,17 @@ class ModelOperator:
|
|
696
789
|
version_name: sql_identifier.SqlIdentifier,
|
697
790
|
statement_params: Optional[Dict[str, Any]] = None,
|
698
791
|
) -> type_hints.Task:
|
699
|
-
|
792
|
+
model_version = self._model_client.show_versions(
|
700
793
|
database_name=database_name,
|
701
794
|
schema_name=schema_name,
|
702
795
|
model_name=model_name,
|
703
796
|
version_name=version_name,
|
797
|
+
validate_result=True,
|
704
798
|
statement_params=statement_params,
|
705
|
-
)
|
706
|
-
|
799
|
+
)[0]
|
800
|
+
|
801
|
+
model_attributes = json.loads(model_version.model_attributes)
|
802
|
+
task_val = model_attributes.get("task", type_hints.Task.UNKNOWN.value)
|
707
803
|
return type_hints.Task(task_val)
|
708
804
|
|
709
805
|
def get_functions(
|
@@ -100,7 +100,7 @@ class ServiceOperator:
|
|
100
100
|
max_instances: int,
|
101
101
|
cpu_requests: Optional[str],
|
102
102
|
memory_requests: Optional[str],
|
103
|
-
gpu_requests: Optional[str],
|
103
|
+
gpu_requests: Optional[Union[int, str]],
|
104
104
|
num_workers: Optional[int],
|
105
105
|
max_batch_rows: Optional[int],
|
106
106
|
force_rebuild: bool,
|
@@ -161,12 +161,16 @@ class ServiceOperator:
|
|
161
161
|
statement_params=statement_params,
|
162
162
|
)
|
163
163
|
|
164
|
-
# check if the inference service is already running
|
164
|
+
# check if the inference service is already running/suspended
|
165
165
|
model_inference_service_exists = self._check_if_service_exists(
|
166
166
|
database_name=service_database_name,
|
167
167
|
schema_name=service_schema_name,
|
168
168
|
service_name=service_name,
|
169
|
-
service_status_list_if_exists=[
|
169
|
+
service_status_list_if_exists=[
|
170
|
+
service_sql.ServiceStatus.READY,
|
171
|
+
service_sql.ServiceStatus.SUSPENDING,
|
172
|
+
service_sql.ServiceStatus.SUSPENDED,
|
173
|
+
],
|
170
174
|
statement_params=statement_params,
|
171
175
|
)
|
172
176
|
|
@@ -309,7 +313,10 @@ class ServiceOperator:
|
|
309
313
|
set_service_log_metadata_to_model_inference(
|
310
314
|
service_log_meta,
|
311
315
|
model_inference_service,
|
312
|
-
|
316
|
+
(
|
317
|
+
"Model Inference image build is not rebuilding the image, but using a previously built "
|
318
|
+
"image."
|
319
|
+
),
|
313
320
|
)
|
314
321
|
continue
|
315
322
|
|
@@ -366,7 +373,9 @@ class ServiceOperator:
|
|
366
373
|
time.sleep(5)
|
367
374
|
|
368
375
|
if model_inference_service_exists:
|
369
|
-
module_logger.info(
|
376
|
+
module_logger.info(
|
377
|
+
f"Inference service {model_inference_service.display_service_name} has already been deployed."
|
378
|
+
)
|
370
379
|
else:
|
371
380
|
self._finalize_logs(
|
372
381
|
service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params
|
@@ -416,6 +425,8 @@ class ServiceOperator:
|
|
416
425
|
service_status_list_if_exists = [
|
417
426
|
service_sql.ServiceStatus.PENDING,
|
418
427
|
service_sql.ServiceStatus.READY,
|
428
|
+
service_sql.ServiceStatus.SUSPENDING,
|
429
|
+
service_sql.ServiceStatus.SUSPENDED,
|
419
430
|
service_sql.ServiceStatus.DONE,
|
420
431
|
service_sql.ServiceStatus.FAILED,
|
421
432
|
]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import List, Optional
|
2
|
+
from typing import List, Optional, Union
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
|
|
38
38
|
max_instances: int,
|
39
39
|
cpu: Optional[str],
|
40
40
|
memory: Optional[str],
|
41
|
-
gpu: Optional[str],
|
41
|
+
gpu: Optional[Union[str, int]],
|
42
42
|
num_workers: Optional[int],
|
43
43
|
max_batch_rows: Optional[int],
|
44
44
|
force_rebuild: bool,
|
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
|
|
86
86
|
service_dict["memory"] = memory
|
87
87
|
|
88
88
|
if gpu:
|
89
|
-
|
89
|
+
if isinstance(gpu, int):
|
90
|
+
gpu_str = str(gpu)
|
91
|
+
else:
|
92
|
+
gpu_str = gpu
|
93
|
+
service_dict["gpu"] = gpu_str
|
90
94
|
|
91
95
|
if num_workers:
|
92
96
|
service_dict["num_workers"] = num_workers
|
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
71
71
|
statement_params=statement_params,
|
72
72
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
73
73
|
|
74
|
+
def create_live_version(
|
75
|
+
self,
|
76
|
+
*,
|
77
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
78
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
79
|
+
model_name: sql_identifier.SqlIdentifier,
|
80
|
+
version_name: sql_identifier.SqlIdentifier,
|
81
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
82
|
+
) -> None:
|
83
|
+
sql = (
|
84
|
+
f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
85
|
+
f" WITH LIVE VERSION {version_name.identifier()}"
|
86
|
+
)
|
87
|
+
query_result_checker.SqlResultValidator(
|
88
|
+
self._session,
|
89
|
+
sql,
|
90
|
+
statement_params=statement_params,
|
91
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
92
|
+
|
93
|
+
def add_live_version(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
97
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
|
+
model_name: sql_identifier.SqlIdentifier,
|
99
|
+
version_name: sql_identifier.SqlIdentifier,
|
100
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
101
|
+
) -> None:
|
102
|
+
sql = (
|
103
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
104
|
+
f" ADD LIVE VERSION {version_name.identifier()}"
|
105
|
+
)
|
106
|
+
query_result_checker.SqlResultValidator(
|
107
|
+
self._session,
|
108
|
+
sql,
|
109
|
+
statement_params=statement_params,
|
110
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
|
+
|
112
|
+
def commit_version(
|
113
|
+
self,
|
114
|
+
*,
|
115
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
116
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
117
|
+
model_name: sql_identifier.SqlIdentifier,
|
118
|
+
version_name: sql_identifier.SqlIdentifier,
|
119
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
120
|
+
) -> None:
|
121
|
+
sql = (
|
122
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
123
|
+
f" COMMIT VERSION {version_name.identifier()}"
|
124
|
+
)
|
125
|
+
|
126
|
+
query_result_checker.SqlResultValidator(
|
127
|
+
self._session,
|
128
|
+
sql,
|
129
|
+
statement_params=statement_params,
|
130
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
131
|
+
|
74
132
|
# TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
|
75
133
|
def add_version_from_stage(
|
76
134
|
self,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import textwrap
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
7
|
from snowflake.ml._internal import platform_capabilities
|
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
|
|
11
11
|
sql_identifier,
|
12
12
|
)
|
13
13
|
from snowflake.ml.model._client.sql import _base
|
14
|
+
from snowflake.ml.model._model_composer.model_method import constants
|
14
15
|
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
15
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
16
17
|
|
@@ -19,6 +20,8 @@ class ServiceStatus(enum.Enum):
|
|
19
20
|
UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
|
20
21
|
PENDING = "PENDING" # resource set is being created, can't be used yet
|
21
22
|
READY = "READY" # resource set has been deployed.
|
23
|
+
SUSPENDING = "SUSPENDING" # the service is set to suspended but the resource set is still in deleting state
|
24
|
+
SUSPENDED = "SUSPENDED" # the service is suspended and the resource set is deleted
|
22
25
|
DELETING = "DELETING" # resource set is being deleted
|
23
26
|
FAILED = "FAILED" # resource set has failed and cannot be used anymore
|
24
27
|
DONE = "DONE" # resource set has finished running
|
@@ -41,7 +44,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
41
44
|
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
42
45
|
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
43
46
|
image_repo_name: sql_identifier.SqlIdentifier,
|
44
|
-
gpu: Optional[str],
|
47
|
+
gpu: Optional[Union[str, int]],
|
45
48
|
force_rebuild: bool,
|
46
49
|
external_access_integration: sql_identifier.SqlIdentifier,
|
47
50
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -121,6 +124,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
121
124
|
args_sql_list.append(input_arg_value)
|
122
125
|
args_sql = ", ".join(args_sql_list)
|
123
126
|
|
127
|
+
wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
|
128
|
+
if wide_input:
|
129
|
+
input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
|
130
|
+
args_sql = f"object_construct_keep_null({input_args_sql})"
|
131
|
+
|
124
132
|
if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
|
125
133
|
fully_qualified_service_name = self.fully_qualified_object_name(
|
126
134
|
actual_database_name, actual_schema_name, service_name
|
@@ -1,8 +1,10 @@
|
|
1
1
|
import pathlib
|
2
2
|
import tempfile
|
3
3
|
import uuid
|
4
|
+
import warnings
|
4
5
|
from types import ModuleType
|
5
|
-
from typing import Any, Dict, List, Optional
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
7
|
+
from urllib import parse
|
6
8
|
|
7
9
|
from absl import logging
|
8
10
|
from packaging import requirements
|
@@ -44,7 +46,13 @@ class ModelComposer:
|
|
44
46
|
statement_params: Optional[Dict[str, Any]] = None,
|
45
47
|
) -> None:
|
46
48
|
self.session = session
|
47
|
-
self.stage_path
|
49
|
+
self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
|
50
|
+
if stage_path.startswith("snow://"):
|
51
|
+
# The stage path is a snowflake internal stage path
|
52
|
+
self.stage_path = parse.urlparse(stage_path)
|
53
|
+
else:
|
54
|
+
# The stage path is a user stage path
|
55
|
+
self.stage_path = pathlib.PurePosixPath(stage_path)
|
48
56
|
|
49
57
|
self._workspace = tempfile.TemporaryDirectory()
|
50
58
|
self._packager_workspace = tempfile.TemporaryDirectory()
|
@@ -70,7 +78,20 @@ class ModelComposer:
|
|
70
78
|
|
71
79
|
@property
|
72
80
|
def model_stage_path(self) -> str:
|
73
|
-
|
81
|
+
if isinstance(self.stage_path, parse.ParseResult):
|
82
|
+
model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
|
83
|
+
new_url = parse.ParseResult(
|
84
|
+
scheme=self.stage_path.scheme,
|
85
|
+
netloc=self.stage_path.netloc,
|
86
|
+
path=str(model_file_path),
|
87
|
+
params=self.stage_path.params,
|
88
|
+
query=self.stage_path.query,
|
89
|
+
fragment=self.stage_path.fragment,
|
90
|
+
)
|
91
|
+
return str(parse.urlunparse(new_url))
|
92
|
+
else:
|
93
|
+
assert isinstance(self.stage_path, pathlib.PurePosixPath)
|
94
|
+
return (self.stage_path / self.model_file_rel_path).as_posix()
|
74
95
|
|
75
96
|
@property
|
76
97
|
def model_local_path(self) -> str:
|
@@ -86,6 +107,7 @@ class ModelComposer:
|
|
86
107
|
metadata: Optional[Dict[str, str]] = None,
|
87
108
|
conda_dependencies: Optional[List[str]] = None,
|
88
109
|
pip_requirements: Optional[List[str]] = None,
|
110
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
89
111
|
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
90
112
|
python_version: Optional[str] = None,
|
91
113
|
user_files: Optional[Dict[str, List[str]]] = None,
|
@@ -94,8 +116,32 @@ class ModelComposer:
|
|
94
116
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
95
117
|
options: Optional[model_types.ModelSaveOption] = None,
|
96
118
|
) -> model_meta.ModelMetadata:
|
119
|
+
# set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
|
120
|
+
conda_dep_dict = env_utils.validate_conda_dependency_string_list(
|
121
|
+
conda_dependencies if conda_dependencies else []
|
122
|
+
)
|
123
|
+
is_warehouse_runnable = (
|
124
|
+
not conda_dep_dict
|
125
|
+
or all(
|
126
|
+
chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
127
|
+
for chan in conda_dep_dict
|
128
|
+
)
|
129
|
+
) and (not pip_requirements)
|
130
|
+
disable_explainability = (
|
131
|
+
target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
132
|
+
) or (not is_warehouse_runnable)
|
133
|
+
|
134
|
+
if disable_explainability and options and options.get("enable_explainability", False):
|
135
|
+
warnings.warn(
|
136
|
+
("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
|
137
|
+
category=UserWarning,
|
138
|
+
stacklevel=2,
|
139
|
+
)
|
140
|
+
|
97
141
|
if not options:
|
98
142
|
options = model_types.BaseModelSaveOption()
|
143
|
+
if disable_explainability:
|
144
|
+
options["enable_explainability"] = False
|
99
145
|
|
100
146
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
101
147
|
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
@@ -120,6 +166,7 @@ class ModelComposer:
|
|
120
166
|
metadata=metadata,
|
121
167
|
conda_dependencies=conda_dependencies,
|
122
168
|
pip_requirements=pip_requirements,
|
169
|
+
artifact_repository_map=artifact_repository_map,
|
123
170
|
python_version=python_version,
|
124
171
|
ext_modules=ext_modules,
|
125
172
|
code_paths=code_paths,
|
@@ -36,7 +36,6 @@ class ModelManifest:
|
|
36
36
|
"""
|
37
37
|
|
38
38
|
MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
|
39
|
-
_ENABLE_USER_FILES = False
|
40
39
|
_DEFAULT_RUNTIME_NAME = "python_runtime"
|
41
40
|
|
42
41
|
def __init__(self, workspace_path: pathlib.Path) -> None:
|
@@ -78,6 +77,7 @@ class ModelManifest:
|
|
78
77
|
logger.info("Relaxing version constraints for dependencies in the model.")
|
79
78
|
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
80
79
|
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
80
|
+
logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
|
81
81
|
runtime_dict = runtime_to_use.save(
|
82
82
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
83
83
|
)
|
@@ -124,6 +124,9 @@ class ModelManifest:
|
|
124
124
|
if len(model_meta.env.pip_requirements) > 0:
|
125
125
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
126
126
|
|
127
|
+
if model_meta.env.artifact_repository_map:
|
128
|
+
dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
|
129
|
+
|
127
130
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
128
131
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
129
132
|
runtimes={
|
@@ -145,7 +148,7 @@ class ModelManifest:
|
|
145
148
|
],
|
146
149
|
)
|
147
150
|
|
148
|
-
if self.
|
151
|
+
if self.user_files:
|
149
152
|
manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
|
150
153
|
|
151
154
|
lineage_sources = self._extract_lineage_info(data_sources)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
2
|
import enum
|
3
|
-
from typing import Any, Dict, List, Literal, TypedDict, Union
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
21
|
conda: NotRequired[str]
|
22
22
|
pip: NotRequired[str]
|
23
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
23
24
|
|
24
25
|
|
25
26
|
class ModelRuntimeDict(TypedDict):
|
@@ -98,7 +98,6 @@ class ModelMethod:
|
|
98
98
|
def _get_method_arg_from_feature(
|
99
99
|
feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
|
100
100
|
) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
|
101
|
-
assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
|
102
101
|
try:
|
103
102
|
feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
|
104
103
|
except ValueError as e:
|
@@ -3,7 +3,7 @@ import itertools
|
|
3
3
|
import os
|
4
4
|
import pathlib
|
5
5
|
import warnings
|
6
|
-
from typing import DefaultDict, List, Optional
|
6
|
+
from typing import DefaultDict, Dict, List, Optional
|
7
7
|
|
8
8
|
from packaging import requirements, version
|
9
9
|
|
@@ -36,6 +36,7 @@ class ModelEnv:
|
|
36
36
|
pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
|
37
37
|
self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
|
38
38
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
|
39
|
+
self.artifact_repository_map: Optional[Dict[str, str]] = None
|
39
40
|
self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
|
40
41
|
self._pip_requirements: List[requirements.Requirement] = []
|
41
42
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
@@ -345,6 +346,7 @@ class ModelEnv:
|
|
345
346
|
def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
|
346
347
|
self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
|
347
348
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
|
349
|
+
self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
|
348
350
|
|
349
351
|
self.load_from_conda_file(base_dir / self.conda_env_rel_path)
|
350
352
|
self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
|
@@ -373,6 +375,7 @@ class ModelEnv:
|
|
373
375
|
return {
|
374
376
|
"conda": self.conda_env_rel_path.as_posix(),
|
375
377
|
"pip": self.pip_requirements_rel_path.as_posix(),
|
378
|
+
"artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
|
376
379
|
"python_version": self.python_version,
|
377
380
|
"cuda_version": self.cuda_version,
|
378
381
|
"snowpark_ml_version": self.snowpark_ml_version,
|