snowflake-ml-python: snowflake_ml_python-1.8.5-py3-none-any.whl → snowflake_ml_python-1.9.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
|
|
38
38
|
def __init__(self) -> None:
|
39
39
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
40
40
|
|
41
|
+
def _repr_html_(self) -> str:
    """Generate an HTML representation of the model version.

    Rendering is best effort for the task, functions and metrics sections: any
    exception raised while collecting one of them is replaced by a placeholder
    message instead of propagating. The name/description attributes are read
    outside of those guards.

    Returns:
        str: HTML string containing formatted model version details.
    """
    import html

    from snowflake.ml.utils import html_utils

    def _escape(value: object) -> object:
        # Escape text only; non-string values (bool, None, ...) are passed through untouched
        # so they format exactly as before.
        return html.escape(value) if isinstance(value, str) else value

    # Get task
    try:
        task = self.get_model_task().value
    except Exception:
        # Reuse the shared placeholder but strip its italic <em> wrapper.
        task = (
            html_utils.create_error_message("Not available")
            .replace('<em style="color: #888; font-style: italic;">', "")
            .replace("</em>", "")
        )

    # Get functions info for display
    try:
        functions = self.show_functions()
        if not functions:
            functions_html = html_utils.create_error_message("No functions available")
        else:
            functions_list = []
            for func in functions:
                try:
                    sig_html = func["signature"]._repr_html_()
                except Exception:
                    # Fallback to simple display if can't display signature.
                    # The plain-text repr is escaped so it cannot break out of the <pre> element.
                    sig_html = f"<pre style='margin: 5px 0;'>{_escape(str(func['signature']))}</pre>"

                function_content = f"""
                    <div style="margin: 5px 0;">
                        <strong>Target Method:</strong> {_escape(func['target_method'])}
                    </div>
                    <div style="margin: 5px 0;">
                        <strong>Function Type:</strong> {_escape(func.get('target_method_function_type', 'N/A'))}
                    </div>
                    <div style="margin: 5px 0;">
                        <strong>Partitioned:</strong> {_escape(func.get('is_partitioned', False))}
                    </div>
                    <div style="margin: 10px 0;">
                        <strong>Signature:</strong>
                        {sig_html}
                    </div>
                """

                functions_list.append(
                    html_utils.create_collapsible_section(
                        title=func["name"], content=function_content, open_by_default=False
                    )
                )
            functions_html = "".join(functions_list)
    except Exception:
        functions_html = html_utils.create_error_message("Error retrieving functions")

    # Get metrics for display
    try:
        metrics = self.show_metrics()
        if not metrics:
            metrics_html = html_utils.create_error_message("No metrics available")
        else:
            metrics_html = "".join(
                html_utils.create_metric_item(metric_name, value) for metric_name, value in metrics.items()
            )
    except Exception:
        metrics_html = html_utils.create_error_message("Error retrieving metrics")

    # Create main content sections.
    # NOTE(review): grid values are presumably inserted as raw HTML (the version entry passes
    # markup through the same list), so user-supplied text is escaped here -- confirm against
    # html_utils.create_grid_section.
    main_info = html_utils.create_grid_section(
        [
            ("Model Name", _escape(self.model_name)),
            ("Version", f'<strong style="color: #28a745;">{_escape(self.version_name)}</strong>'),
            ("Full Name", _escape(self.fully_qualified_model_name)),
            ("Description", _escape(self.description)),
            ("Task", task),
        ]
    )

    functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
        functions_html
    )

    metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)

    content = main_info + functions_section + metrics_section

    return html_utils.create_base_container("Model Version Details", content)
|
130
|
+
|
41
131
|
@classmethod
|
42
132
|
def _ref(
|
43
133
|
cls,
|
@@ -643,14 +643,17 @@ class ModelOperator:
|
|
643
643
|
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
644
644
|
fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
645
645
|
|
646
|
-
result = []
|
647
|
-
|
646
|
+
result: list[ServiceInfo] = []
|
648
647
|
for fully_qualified_service_name in fully_qualified_service_names:
|
649
648
|
ingress_url: Optional[str] = None
|
650
649
|
db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
|
651
|
-
|
650
|
+
statuses = self._service_client.get_service_container_statuses(
|
652
651
|
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
653
652
|
)
|
653
|
+
if len(statuses) == 0:
|
654
|
+
return result
|
655
|
+
|
656
|
+
service_status = statuses[0].service_status
|
654
657
|
for res_row in self._service_client.show_endpoints(
|
655
658
|
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
656
659
|
):
|
@@ -952,7 +955,7 @@ class ModelOperator:
|
|
952
955
|
output_with_input_features = False
|
953
956
|
df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation)
|
954
957
|
s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
|
955
|
-
self._session, df, keep_order=keep_order, features=signature.inputs
|
958
|
+
self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
|
956
959
|
)
|
957
960
|
else:
|
958
961
|
keep_order = False
|
@@ -966,9 +969,16 @@ class ModelOperator:
|
|
966
969
|
|
967
970
|
# Compose input and output names
|
968
971
|
input_args = []
|
972
|
+
quoted_identifiers_ignore_case = (
|
973
|
+
snowpark_handler.SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
|
974
|
+
self._session, statement_params
|
975
|
+
)
|
976
|
+
)
|
977
|
+
|
969
978
|
for input_feature in signature.inputs:
|
970
979
|
col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name)
|
971
|
-
|
980
|
+
if quoted_identifiers_ignore_case:
|
981
|
+
col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
|
972
982
|
input_args.append(col_name)
|
973
983
|
|
974
984
|
returns = []
|
@@ -1048,7 +1058,9 @@ class ModelOperator:
|
|
1048
1058
|
|
1049
1059
|
# Get final result
|
1050
1060
|
if not isinstance(X, dataframe.DataFrame):
|
1051
|
-
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
|
1061
|
+
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
|
1062
|
+
df_res, features=signature.outputs, statement_params=statement_params
|
1063
|
+
)
|
1052
1064
|
else:
|
1053
1065
|
return df_res
|
1054
1066
|
|
@@ -325,13 +325,14 @@ class ServiceOperator:
|
|
325
325
|
)
|
326
326
|
continue
|
327
327
|
|
328
|
-
|
328
|
+
statuses = self._service_client.get_service_container_statuses(
|
329
329
|
database_name=service_log_meta.service.database_name,
|
330
330
|
schema_name=service_log_meta.service.schema_name,
|
331
331
|
service_name=service_log_meta.service.service_name,
|
332
332
|
include_message=True,
|
333
333
|
statement_params=statement_params,
|
334
334
|
)
|
335
|
+
service_status = statuses[0].service_status
|
335
336
|
if (service_status != service_sql.ServiceStatus.RUNNING) or (
|
336
337
|
service_status != service_log_meta.service_status
|
337
338
|
):
|
@@ -341,7 +342,19 @@ class ServiceOperator:
|
|
341
342
|
f"{service_log_meta.service.display_service_name} is "
|
342
343
|
f"{service_log_meta.service_status.value}."
|
343
344
|
)
|
344
|
-
|
345
|
+
for status in statuses:
|
346
|
+
if status.instance_id is not None:
|
347
|
+
instance_status, container_status = None, None
|
348
|
+
if status.instance_status is not None:
|
349
|
+
instance_status = status.instance_status.value
|
350
|
+
if status.container_status is not None:
|
351
|
+
container_status = status.container_status.value
|
352
|
+
module_logger.info(
|
353
|
+
f"Instance[{status.instance_id}]: "
|
354
|
+
f"instance status: {instance_status}, "
|
355
|
+
f"container status: {container_status}, "
|
356
|
+
f"message: {status.message}"
|
357
|
+
)
|
345
358
|
|
346
359
|
new_logs, new_offset = fetch_logs(
|
347
360
|
service_log_meta.service,
|
@@ -353,13 +366,14 @@ class ServiceOperator:
|
|
353
366
|
|
354
367
|
# check if model build service is done
|
355
368
|
if not service_log_meta.is_model_build_service_done:
|
356
|
-
|
369
|
+
statuses = self._service_client.get_service_container_statuses(
|
357
370
|
database_name=model_build_service.database_name,
|
358
371
|
schema_name=model_build_service.schema_name,
|
359
372
|
service_name=model_build_service.service_name,
|
360
373
|
include_message=False,
|
361
374
|
statement_params=statement_params,
|
362
375
|
)
|
376
|
+
service_status = statuses[0].service_status
|
363
377
|
|
364
378
|
if service_status == service_sql.ServiceStatus.DONE:
|
365
379
|
set_service_log_metadata_to_model_inference(
|
@@ -436,13 +450,14 @@ class ServiceOperator:
|
|
436
450
|
service_sql.ServiceStatus.FAILED,
|
437
451
|
]
|
438
452
|
try:
|
439
|
-
|
453
|
+
statuses = self._service_client.get_service_container_statuses(
|
440
454
|
database_name=database_name,
|
441
455
|
schema_name=schema_name,
|
442
456
|
service_name=service_name,
|
443
457
|
include_message=False,
|
444
458
|
statement_params=statement_params,
|
445
459
|
)
|
460
|
+
service_status = statuses[0].service_status
|
446
461
|
return any(service_status == status for status in service_status_list_if_exists)
|
447
462
|
except exceptions.SnowparkSQLException:
|
448
463
|
return False
|
@@ -503,7 +518,7 @@ class ServiceOperator:
|
|
503
518
|
output_with_input_features = False
|
504
519
|
df = model_signature._convert_and_validate_local_data(X, signature.inputs)
|
505
520
|
s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
|
506
|
-
self._session, df, keep_order=keep_order, features=signature.inputs
|
521
|
+
self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
|
507
522
|
)
|
508
523
|
else:
|
509
524
|
keep_order = False
|
@@ -615,7 +630,9 @@ class ServiceOperator:
|
|
615
630
|
|
616
631
|
# get final result
|
617
632
|
if not isinstance(X, dataframe.DataFrame):
|
618
|
-
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
|
633
|
+
return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
|
634
|
+
df_res, features=signature.outputs, statement_params=statement_params
|
635
|
+
)
|
619
636
|
else:
|
620
637
|
return df_res
|
621
638
|
|
@@ -1,4 +1,6 @@
|
|
1
|
+
import dataclasses
|
1
2
|
import enum
|
3
|
+
import logging
|
2
4
|
import textwrap
|
3
5
|
from typing import Any, Optional, Union
|
4
6
|
|
@@ -13,26 +15,59 @@ from snowflake.ml.model._model_composer.model_method import constants
|
|
13
15
|
from snowflake.snowpark import dataframe, functions as F, row, types as spt
|
14
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
15
17
|
|
18
|
+
logger = logging.getLogger(__name__)
|
19
|
+
|
16
20
|
|
17
|
-
# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
|
18
|
-
# except UNKNOWN
|
19
21
|
class ServiceStatus(enum.Enum):
    """Service-level status, parsed from the `service_status` column of SHOW SERVICE CONTAINERS IN SERVICE.

    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
    """

    PENDING = "PENDING"  # resource set is being created, can't be used yet
    RUNNING = "RUNNING"
    FAILED = "FAILED"  # resource set has failed and cannot be used anymore
    DONE = "DONE"  # resource set has finished running
    SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
    SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
    DELETING = "DELETING"  # resource set is being deleted
    DELETED = "DELETED"
    INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
|
31
|
+
|
32
|
+
|
33
|
+
class InstanceStatus(enum.Enum):
    """Status of one service instance, parsed from the `instance_status` column of SHOW SERVICE CONTAINERS IN SERVICE."""

    PENDING = "PENDING"
    READY = "READY"
    FAILED = "FAILED"
    TERMINATING = "TERMINATING"
    SUCCEEDED = "SUCCEEDED"
|
39
|
+
|
40
|
+
|
41
|
+
class ContainerStatus(enum.Enum):
    """Status of one service container, parsed from the `status` column of SHOW SERVICE CONTAINERS IN SERVICE."""

    PENDING = "PENDING"
    READY = "READY"
    DONE = "DONE"
    FAILED = "FAILED"
    UNKNOWN = "UNKNOWN"
|
47
|
+
|
48
|
+
|
49
|
+
@dataclasses.dataclass
class ServiceStatusInfo:
    """
    Class containing information about service container status.
    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service

    One instance is built per row returned by `ServiceSQLClient.get_service_container_statuses`.
    """

    # Parsed from the row's `service_status` column; always present.
    service_status: ServiceStatus
    # Raw value of the row's `instance_id` column.
    instance_id: Optional[int] = None
    # Parsed from the row's `instance_status` column; None when that column is NULL.
    instance_status: Optional[InstanceStatus] = None
    # Parsed from the row's `status` column; None when that column is NULL.
    container_status: Optional[ContainerStatus] = None
    # Row's `message` column; only populated when the caller asks for it (`include_message=True`).
    message: Optional[str] = None
|
30
61
|
|
31
62
|
|
32
63
|
class ServiceSQLClient(_base._BaseSQLClient):
|
33
64
|
MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
|
34
65
|
MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
|
35
66
|
SERVICE_STATUS = "service_status"
|
67
|
+
INSTANCE_ID = "instance_id"
|
68
|
+
INSTANCE_STATUS = "instance_status"
|
69
|
+
CONTAINER_STATUS = "status"
|
70
|
+
MESSAGE = "message"
|
36
71
|
|
37
72
|
def build_model_container(
|
38
73
|
self,
|
@@ -81,6 +116,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
81
116
|
) -> tuple[str, snowpark.AsyncJob]:
|
82
117
|
assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
|
83
118
|
if model_deployment_spec_yaml_str:
|
119
|
+
model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
|
120
|
+
model_deployment_spec_yaml_str
|
121
|
+
) # type: ignore[no-untyped-call]
|
122
|
+
logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
|
84
123
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
|
85
124
|
else:
|
86
125
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
@@ -192,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
192
231
|
)
|
193
232
|
return str(rows[0][system_func])
|
194
233
|
|
195
|
-
def
|
234
|
+
def get_service_container_statuses(
|
196
235
|
self,
|
197
236
|
*,
|
198
237
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -200,18 +239,27 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
200
239
|
service_name: sql_identifier.SqlIdentifier,
|
201
240
|
include_message: bool = False,
|
202
241
|
statement_params: Optional[dict[str, Any]] = None,
|
203
|
-
) ->
|
242
|
+
) -> list[ServiceStatusInfo]:
|
204
243
|
fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
|
205
244
|
query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
|
206
245
|
rows = self._session.sql(query).collect(statement_params=statement_params)
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
246
|
+
statuses = []
|
247
|
+
for r in rows:
|
248
|
+
instance_status, container_status = None, None
|
249
|
+
if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
|
250
|
+
instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
|
251
|
+
if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
|
252
|
+
container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
|
253
|
+
statuses.append(
|
254
|
+
ServiceStatusInfo(
|
255
|
+
service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
|
256
|
+
instance_id=r[ServiceSQLClient.INSTANCE_ID],
|
257
|
+
instance_status=instance_status,
|
258
|
+
container_status=container_status,
|
259
|
+
message=r[ServiceSQLClient.MESSAGE] if include_message else None,
|
260
|
+
)
|
261
|
+
)
|
262
|
+
return statuses
|
215
263
|
|
216
264
|
def drop_service(
|
217
265
|
self,
|
@@ -12,9 +12,12 @@ class StageSQLClient(_base._BaseSQLClient):
|
|
12
12
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
13
13
|
stage_name: sql_identifier.SqlIdentifier,
|
14
14
|
statement_params: Optional[dict[str, Any]] = None,
|
15
|
-
) ->
|
15
|
+
) -> str:
|
16
|
+
fq_stage_name = self.fully_qualified_object_name(database_name, schema_name, stage_name)
|
16
17
|
query_result_checker.SqlResultValidator(
|
17
18
|
self._session,
|
18
|
-
f"CREATE SCOPED TEMPORARY STAGE {
|
19
|
+
f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}",
|
19
20
|
statement_params=statement_params,
|
20
21
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
22
|
+
|
23
|
+
return fq_stage_name
|
@@ -7,6 +7,7 @@ from typing import Optional, cast
|
|
7
7
|
import yaml
|
8
8
|
|
9
9
|
from snowflake.ml._internal import env_utils
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
10
11
|
from snowflake.ml.data import data_source
|
11
12
|
from snowflake.ml.model import type_hints
|
12
13
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
@@ -53,17 +54,44 @@ class ModelManifest:
|
|
53
54
|
if options is None:
|
54
55
|
options = {}
|
55
56
|
|
57
|
+
has_pip_requirements = len(model_meta.env.pip_requirements) > 0
|
58
|
+
only_spcs = (
|
59
|
+
target_platforms
|
60
|
+
and len(target_platforms) == 1
|
61
|
+
and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
|
62
|
+
)
|
63
|
+
|
56
64
|
if "relax_version" not in options:
|
57
|
-
|
58
|
-
(
|
59
|
-
"`relax_version`
|
60
|
-
"
|
61
|
-
"
|
62
|
-
)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
65
|
+
if has_pip_requirements or only_spcs:
|
66
|
+
logger.info(
|
67
|
+
"Setting `relax_version=False` as this model will run in Snowpark Container Services "
|
68
|
+
"or in Warehouse with a specified artifact_repository_map where exact version "
|
69
|
+
" specifications will be honored."
|
70
|
+
)
|
71
|
+
relax_version = False
|
72
|
+
else:
|
73
|
+
warnings.warn(
|
74
|
+
(
|
75
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
|
76
|
+
" relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
|
77
|
+
" reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
78
|
+
),
|
79
|
+
category=UserWarning,
|
80
|
+
stacklevel=2,
|
81
|
+
)
|
82
|
+
relax_version = True
|
83
|
+
options["relax_version"] = relax_version
|
84
|
+
else:
|
85
|
+
relax_version = options.get("relax_version", True)
|
86
|
+
if relax_version and (has_pip_requirements or only_spcs):
|
87
|
+
raise exceptions.SnowflakeMLException(
|
88
|
+
error_code=error_codes.INVALID_ARGUMENT,
|
89
|
+
original_exception=ValueError(
|
90
|
+
"Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
|
91
|
+
"Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
|
92
|
+
"targeting only Snowpark Container Services."
|
93
|
+
),
|
94
|
+
)
|
67
95
|
|
68
96
|
runtime_to_use = model_runtime.ModelRuntime(
|
69
97
|
name=self._DEFAULT_RUNTIME_NAME,
|
@@ -9,6 +9,7 @@ from packaging import requirements, version
|
|
9
9
|
|
10
10
|
from snowflake.ml import version as snowml_version
|
11
11
|
from snowflake.ml._internal import env as snowml_env, env_utils
|
12
|
+
from snowflake.ml.model import type_hints as model_types
|
12
13
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
13
14
|
|
14
15
|
# requirement: Full version requirement where name is conda package name.
|
@@ -30,6 +31,7 @@ class ModelEnv:
|
|
30
31
|
conda_env_rel_path: Optional[str] = None,
|
31
32
|
pip_requirements_rel_path: Optional[str] = None,
|
32
33
|
prefer_pip: bool = False,
|
34
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
33
35
|
) -> None:
|
34
36
|
if conda_env_rel_path is None:
|
35
37
|
conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
|
@@ -45,6 +47,8 @@ class ModelEnv:
|
|
45
47
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
46
48
|
self._cuda_version: Optional[version.Version] = None
|
47
49
|
self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
|
50
|
+
self._target_platforms = target_platforms
|
51
|
+
self._warnings_shown: set[str] = set()
|
48
52
|
|
49
53
|
@property
|
50
54
|
def conda_dependencies(self) -> list[str]:
|
@@ -116,6 +120,17 @@ class ModelEnv:
|
|
116
120
|
if snowpark_ml_version:
|
117
121
|
self._snowpark_ml_version = version.parse(snowpark_ml_version)
|
118
122
|
|
123
|
+
@property
def targets_warehouse(self) -> bool:
    """Whether Snowflake Warehouse is among the target platforms.

    An unset platform list (None) counts as targeting the warehouse.
    """
    platforms = self._target_platforms
    if platforms is None:
        return True
    return model_types.TargetPlatform.WAREHOUSE in platforms
|
127
|
+
|
128
|
+
def _warn_once(self, message: str, stacklevel: int = 2) -> None:
|
129
|
+
"""Show warning only once per ModelEnv instance."""
|
130
|
+
if message not in self._warnings_shown:
|
131
|
+
warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
|
132
|
+
self._warnings_shown.add(message)
|
133
|
+
|
119
134
|
def include_if_absent(
|
120
135
|
self,
|
121
136
|
pkgs: list[ModelDependency],
|
@@ -130,14 +145,14 @@ class ModelEnv:
|
|
130
145
|
"""
|
131
146
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
132
147
|
pip_pkg_reqs: list[str] = []
|
133
|
-
|
134
|
-
(
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
148
|
+
if self.targets_warehouse:
|
149
|
+
self._warn_once(
|
150
|
+
(
|
151
|
+
"Dependencies specified from pip requirements."
|
152
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
153
|
+
),
|
154
|
+
stacklevel=2,
|
155
|
+
)
|
141
156
|
for conda_req_str, pip_name in pkgs:
|
142
157
|
_, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
|
143
158
|
pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
|
@@ -162,16 +177,15 @@ class ModelEnv:
|
|
162
177
|
req_to_add.name = conda_req.name
|
163
178
|
else:
|
164
179
|
req_to_add = conda_req
|
165
|
-
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
|
180
|
+
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
|
166
181
|
|
167
182
|
if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
|
168
183
|
if show_warning_message:
|
169
|
-
|
184
|
+
self._warn_once(
|
170
185
|
(
|
171
186
|
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
172
187
|
" This may prevent model deploying to Snowflake Warehouse."
|
173
188
|
),
|
174
|
-
category=UserWarning,
|
175
189
|
stacklevel=2,
|
176
190
|
)
|
177
191
|
continue
|
@@ -182,12 +196,11 @@ class ModelEnv:
|
|
182
196
|
pass
|
183
197
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
184
198
|
if show_warning_message:
|
185
|
-
|
199
|
+
self._warn_once(
|
186
200
|
(
|
187
201
|
f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
|
188
202
|
+ " This may prevent model deploying to Snowflake Warehouse."
|
189
203
|
),
|
190
|
-
category=UserWarning,
|
191
204
|
stacklevel=2,
|
192
205
|
)
|
193
206
|
|
@@ -272,22 +285,20 @@ class ModelEnv:
|
|
272
285
|
)
|
273
286
|
|
274
287
|
for channel, channel_dependencies in conda_dependencies_dict.items():
|
275
|
-
if channel != env_utils.DEFAULT_CHANNEL_NAME:
|
276
|
-
|
288
|
+
if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
|
289
|
+
self._warn_once(
|
277
290
|
(
|
278
291
|
"Found dependencies specified in the conda file from non-Snowflake channel."
|
279
292
|
" This may prevent model deploying to Snowflake Warehouse."
|
280
293
|
),
|
281
|
-
category=UserWarning,
|
282
294
|
stacklevel=2,
|
283
295
|
)
|
284
|
-
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
|
285
|
-
|
296
|
+
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
|
297
|
+
self._warn_once(
|
286
298
|
(
|
287
299
|
f"Found additional conda channel {channel} specified in the conda file."
|
288
300
|
" This may prevent model deploying to Snowflake Warehouse."
|
289
301
|
),
|
290
|
-
category=UserWarning,
|
291
302
|
stacklevel=2,
|
292
303
|
)
|
293
304
|
self._conda_dependencies[channel] = []
|
@@ -298,22 +309,20 @@ class ModelEnv:
|
|
298
309
|
except env_utils.DuplicateDependencyError:
|
299
310
|
pass
|
300
311
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
301
|
-
|
312
|
+
self._warn_once(
|
302
313
|
(
|
303
314
|
f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
|
304
315
|
" This may be unintentional."
|
305
316
|
),
|
306
|
-
category=UserWarning,
|
307
317
|
stacklevel=2,
|
308
318
|
)
|
309
319
|
|
310
|
-
if pip_requirements_list:
|
311
|
-
|
320
|
+
if pip_requirements_list and self.targets_warehouse:
|
321
|
+
self._warn_once(
|
312
322
|
(
|
313
323
|
"Found dependencies specified as pip requirements."
|
314
324
|
" This may prevent model deploying to Snowflake Warehouse."
|
315
325
|
),
|
316
|
-
category=UserWarning,
|
317
326
|
stacklevel=2,
|
318
327
|
)
|
319
328
|
for pip_dependency in pip_requirements_list:
|
@@ -333,13 +342,12 @@ class ModelEnv:
|
|
333
342
|
def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
|
334
343
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
335
344
|
|
336
|
-
if pip_requirements_list:
|
337
|
-
|
345
|
+
if pip_requirements_list and self.targets_warehouse:
|
346
|
+
self._warn_once(
|
338
347
|
(
|
339
348
|
"Found dependencies specified as pip requirements."
|
340
349
|
" This may prevent model deploying to Snowflake Warehouse."
|
341
350
|
),
|
342
|
-
category=UserWarning,
|
343
351
|
stacklevel=2,
|
344
352
|
)
|
345
353
|
for pip_dependency in pip_requirements_list:
|
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
167
167
|
model_blob_metadata = model_blobs_metadata[name]
|
168
168
|
model_blob_filename = model_blob_metadata.path
|
169
169
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
170
|
-
m = torch.load(
|
170
|
+
m = torch.load(
|
171
|
+
f,
|
172
|
+
map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
|
173
|
+
weights_only=False,
|
174
|
+
)
|
171
175
|
assert isinstance(m, torch.nn.Module)
|
172
176
|
|
173
177
|
return m
|