snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +2 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +109 -32
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +45 -2
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +81 -61
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +30 -29
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/model_ops.py

@@ -1,5 +1,6 @@
 import enum
 import json
+import logging
 import os
 import pathlib
 import tempfile
@@ -11,9 +12,9 @@ from typing_extensions import NotRequired
 
 from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
+from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
-from snowflake.ml.model._client.ops import metadata_ops
+from snowflake.ml.model._client.ops import deployment_step, metadata_ops, param_utils
 from snowflake.ml.model._client.sql import (
     model as model_sql,
     model_version as model_version_sql,
@@ -33,6 +34,8 @@ from snowflake.ml.model._signatures import snowpark_handler
 from snowflake.snowpark import dataframe, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
+logger = logging.getLogger(__name__)
+
 
 # An enum class to represent Create Or Alter Model SQL command.
 class ModelAction(enum.Enum):
@@ -986,6 +989,7 @@ class ModelOperator:
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
         explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -1002,6 +1006,7 @@ class ModelOperator:
         strict_input_validation: bool = False,
         statement_params: Optional[dict[str, str]] = None,
         explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -1022,6 +1027,7 @@ class ModelOperator:
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
         explain_case_sensitive: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
 
@@ -1057,6 +1063,8 @@ class ModelOperator:
             col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
             input_args.append(col_name)
 
+        method_parameters = param_utils.validate_and_resolve_params(params, signature.params)
+
         returns = []
         for output_feature in signature.outputs:
             output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
@@ -1075,6 +1083,7 @@ class ModelOperator:
                 schema_name=schema_name,
                 service_name=service_name,
                 statement_params=statement_params,
+                params=method_parameters,
             )
         else:
             assert model_name is not None
@@ -1090,6 +1099,7 @@ class ModelOperator:
                 model_name=model_name,
                 version_name=version_name,
                 statement_params=statement_params,
+                params=method_parameters,
             )
         elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
             df_res = self._model_version_client.invoke_table_function_method(
@@ -1105,6 +1115,7 @@ class ModelOperator:
                 statement_params=statement_params,
                 is_partitioned=is_partitioned or False,
                 explain_case_sensitive=explain_case_sensitive,
+                params=method_parameters,
             )
 
         if keep_order:
@@ -1238,3 +1249,35 @@ class ModelOperator:
             target_path=local_file_dir,
             statement_params=statement_params,
         )
+
+    def run_import_model_query(
+        self,
+        *,
+        database_name: str,
+        schema_name: str,
+        yaml_content: str,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        yaml_content_escaped = snowpark_utils.escape_single_quotes(yaml_content)  # type: ignore[no-untyped-call]
+
+        async_job = self._session.sql(
+            f"SELECT SYSTEM$IMPORT_MODEL('{yaml_content_escaped}')",
+        ).collect(block=False, statement_params=statement_params)
+        query_id = async_job.query_id  # type: ignore[attr-defined]
+
+        logger.info(f"Remotely importing model, with the query id: {query_id}")
+        model_logger_service_name = sql_identifier.SqlIdentifier(
+            deployment_step.get_service_id_from_deployment_step(
+                query_id,
+                deployment_step.DeploymentStep.MODEL_LOGGING,
+            )
+        )
+
+        logger_name = model_logger_service_name.identifier()
+        job_url = f"{url.JOB_URL_PREFIX}/{database_name}/{schema_name}/{logger_name}"
+        snowflake_url = url.get_snowflake_url(session=self._session, url_path=job_url)
+        logger.info(
+            f"To monitor the progress of the model logging job, head to the job monitoring page {snowflake_url}"
+        )
+
+        async_job.result()  # type: ignore[attr-defined]
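The new run_import_model_query drives a server-side SYSTEM$IMPORT_MODEL job and derives the model-logging service name from the query id via deployment_step.get_service_id_from_deployment_step, whose implementation is not part of this diff. Judging from the static method it replaces (removed from service_ops.py further down), the derivation looks roughly like the stdlib-only sketch below; treating the new helper as identical to the removed one is an assumption.

    # Sketch of the service-name derivation performed by the ServiceOperator
    # method removed in this release; whether the new deployment_step helper
    # matches it byte-for-byte is an assumption.
    import hashlib

    def service_id_from_query_id(query_id: str, service_name_prefix: str) -> str:
        uuid = query_id.replace("-", "")  # strip dashes from the query UUID
        big_int = int(uuid, 16)  # interpret the remaining hex digits as an integer
        md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
        return (service_name_prefix + md5_hash[:8]).upper()

    # Hypothetical query id; "model_logging_" is the prefix tied to MODEL_LOGGING.
    print(service_id_from_query_id("01b2c3d4-0000-0000-0000-0123456789ab", "model_logging_"))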
snowflake/ml/model/_client/ops/param_utils.py

@@ -0,0 +1,124 @@
+"""Utility functions for model parameter validation and resolution."""
+
+from typing import Any, Optional, Sequence
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._signatures import core
+
+
+def validate_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> None:
+    """Validate user-provided params against signature params.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Raises:
+        SnowflakeMLException: If params are provided but signature has no params,
+            or if unknown params are provided, or if param types are invalid,
+            or if duplicate params are provided with different cases.
+    """
+    # Params provided but signature has no params defined
+    if params and not signature_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Parameters were provided ({sorted(params.keys())}), "
+                "but this method does not accept any parameters."
+            ),
+        )
+
+    if not signature_params or not params:
+        return
+
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+    normalized_names = [name.upper() for name in params]
+    if len(normalized_names) != len(set(normalized_names)):
+        # Find the duplicate params to raise an error
+        param_seen: dict[str, list[str]] = {}
+        for param_name in params:
+            param_seen.setdefault(param_name.upper(), []).append(param_name)
+        duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+                "Parameter names are case-insensitive."
+            ),
+        )
+
+    # Validate user-provided params exist (case-insensitive)
+    invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+    if invalid_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Unknown parameter(s): {sorted(invalid_params)}. "
+                f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+            ),
+        )
+
+    # Validate types for each provided param
+    for param_name, default_value in params.items():
+        param_spec = param_spec_lookup[param_name.upper()]
+        if isinstance(param_spec, core.ParamSpec):
+            core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+def resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Sequence[core.BaseParamSpec],
+) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+    """Resolve final method parameters by applying user-provided params over signature defaults.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation.
+    """
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Start with defaults from signature
+    final_params: dict[str, Any] = {}
+    for param_spec in signature_params:
+        if hasattr(param_spec, "default_value"):
+            final_params[param_spec.name] = param_spec.default_value
+
+    # Override with provided runtime parameters (using signature's original param names)
+    if params:
+        for param_name, override_value in params.items():
+            canonical_name = param_spec_lookup[param_name.upper()].name
+            final_params[canonical_name] = override_value
+
+    return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+def validate_and_resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+    """Validate user-provided params against signature params and return method parameters.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+    """
+    validate_params(params, signature_params)
+
+    if not signature_params:
+        return None
+
+    return resolve_params(params, signature_params)
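param_utils resolves the final method parameters by starting from the signature's declared defaults and overriding them case-insensitively with the caller-supplied values, always keeping the signature's canonical names. A self-contained sketch of that merge, using a hypothetical _Param stand-in rather than the real core.ParamSpec class:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class _Param:  # hypothetical stand-in for core.ParamSpec
        name: str
        default_value: Any

    def resolve(params: Optional[dict[str, Any]], signature_params: list[_Param]) -> list[tuple[str, Any]]:
        lookup = {p.name.upper(): p for p in signature_params}  # case-insensitive lookup
        final = {p.name: p.default_value for p in signature_params}  # start from signature defaults
        for name, value in (params or {}).items():
            final[lookup[name.upper()].name] = value  # override under the canonical name
        return list(final.items())

    sig = [_Param("temperature", 0.7), _Param("max_tokens", 256)]
    print(resolve({"TEMPERATURE": 0.1}, sig))  # [('temperature', 0.1), ('max_tokens', 256)]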
snowflake/ml/model/_client/ops/service_ops.py

@@ -1,6 +1,6 @@
+import base64
 import dataclasses
-import enum
-import hashlib
+import json
 import logging
 import pathlib
 import re
@@ -8,7 +8,9 @@ import tempfile
 import threading
 import time
 import warnings
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Sequence, Union, cast
+
+from pydantic import TypeAdapter
 
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -16,8 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
 from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
+from snowflake.ml.model._client.ops import deployment_step, param_utils
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -25,32 +29,12 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
 module_logger.propagate = False
 
 
-class DeploymentStep(enum.Enum):
-    MODEL_BUILD = ("model-build", "model_build_")
-    MODEL_INFERENCE = ("model-inference", None)
-    MODEL_LOGGING = ("model-logging", "model_logging_")
-
-    def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
-        self._container_name = container_name
-        self._service_name_prefix = service_name_prefix
-
-    @property
-    def container_name(self) -> str:
-        """Get the container name for the deployment step."""
-        return self._container_name
-
-    @property
-    def service_name_prefix(self) -> Optional[str]:
-        """Get the service name prefix for the deployment step."""
-        return self._service_name_prefix
-
-
 @dataclasses.dataclass
 class ServiceLogInfo:
     database_name: Optional[sql_identifier.SqlIdentifier]
     schema_name: Optional[sql_identifier.SqlIdentifier]
     service_name: sql_identifier.SqlIdentifier
-    deployment_step: DeploymentStep
+    deployment_step: deployment_step.DeploymentStep
     instance_id: str = "0"
     log_color: service_logger.LogColor = service_logger.LogColor.GREY
 
@@ -353,13 +337,16 @@ class ServiceOperator:
         if is_enable_image_build:
             # stream service logs in a thread
             model_build_service_name = sql_identifier.SqlIdentifier(
-
+                deployment_step.get_service_id_from_deployment_step(
+                    query_id,
+                    deployment_step.DeploymentStep.MODEL_BUILD,
+                )
             )
             model_build_service = ServiceLogInfo(
                 database_name=service_database_name,
                 schema_name=service_schema_name,
                 service_name=model_build_service_name,
-                deployment_step=DeploymentStep.MODEL_BUILD,
+                deployment_step=deployment_step.DeploymentStep.MODEL_BUILD,
                 log_color=service_logger.LogColor.GREEN,
             )
@@ -367,21 +354,23 @@ class ServiceOperator:
             database_name=service_database_name,
             schema_name=service_schema_name,
             service_name=service_name,
-            deployment_step=DeploymentStep.MODEL_INFERENCE,
+            deployment_step=deployment_step.DeploymentStep.MODEL_INFERENCE,
            log_color=service_logger.LogColor.BLUE,
         )
 
         model_logger_service: Optional[ServiceLogInfo] = None
         if hf_model_args:
             model_logger_service_name = sql_identifier.SqlIdentifier(
-
+                deployment_step.get_service_id_from_deployment_step(
+                    query_id, deployment_step.DeploymentStep.MODEL_LOGGING
+                )
             )
 
             model_logger_service = ServiceLogInfo(
                 database_name=service_database_name,
                 schema_name=service_schema_name,
                 service_name=model_logger_service_name,
-                deployment_step=DeploymentStep.MODEL_LOGGING,
+                deployment_step=deployment_step.DeploymentStep.MODEL_LOGGING,
                 log_color=service_logger.LogColor.ORANGE,
             )
 
@@ -536,7 +525,7 @@ class ServiceOperator:
             service = service_log_meta.service
             # check if using an existing model build image
             if (
-                service.deployment_step == DeploymentStep.MODEL_BUILD
+                service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD
                 and not force_rebuild
                 and service_log_meta.is_model_logger_service_done
                 and not service_log_meta.is_model_build_service_done
@@ -582,31 +571,26 @@ class ServiceOperator:
             if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
                 service_log_meta.service_status = service_status
 
-                if service.deployment_step == DeploymentStep.MODEL_BUILD:
+                if service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
                     module_logger.info(
                         f"Image build service {service.display_service_name} is "
                         f"{service_log_meta.service_status.value}."
                     )
-                elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+                elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
                     module_logger.info(
                         f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
                     )
-                elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
+                elif service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
                     module_logger.info(
                         f"Model logger service {service.display_service_name} is "
                         f"{service_log_meta.service_status.value}."
                     )
                 for status in statuses:
                     if status.instance_id is not None:
-                        instance_status, container_status = None, None
-                        if status.instance_status is not None:
-                            instance_status = status.instance_status.value
-                        if status.container_status is not None:
-                            container_status = status.container_status.value
                         module_logger.info(
                             f"Instance[{status.instance_id}]: "
-                            f"instance status: {instance_status}, "
-                            f"container status: {container_status}, "
+                            f"instance status: {status.instance_status}, "
+                            f"container status: {status.container_status}, "
                             f"message: {status.message}"
                         )
             time.sleep(5)
@@ -627,7 +611,7 @@ class ServiceOperator:
             if service_status == service_sql.ServiceStatus.DONE:
                 # check if model logger service is done
                 # and transition the service log metadata to the model image build service
-                if service.deployment_step == DeploymentStep.MODEL_LOGGING:
+                if service.deployment_step == deployment_step.DeploymentStep.MODEL_LOGGING:
                     if model_build_service:
                         # building the inference image, transition to the model build service
                         service_log_meta.transition_service_log_metadata(
@@ -648,7 +632,7 @@ class ServiceOperator:
                         )
                     # check if model build service is done
                     # and transition the service log metadata to the model inference service
-                elif service.deployment_step == DeploymentStep.MODEL_BUILD:
+                elif service.deployment_step == deployment_step.DeploymentStep.MODEL_BUILD:
                     service_log_meta.transition_service_log_metadata(
                         model_inference_service,
                         f"Image build service {service.display_service_name} complete.",
@@ -656,7 +640,7 @@ class ServiceOperator:
                         is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
                         operation_id=operation_id,
                     )
-                elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
+                elif service.deployment_step == deployment_step.DeploymentStep.MODEL_INFERENCE:
                     module_logger.info(f"Inference service {service.display_service_name} is deployed.")
                 else:
                     module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -916,19 +900,6 @@ class ServiceOperator:
 
             time.sleep(2)  # Poll every 2 seconds
 
-    @staticmethod
-    def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
-        """Get the service ID through the server-side logic."""
-        uuid = query_id.replace("-", "")
-        big_int = int(uuid, 16)
-        md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
-        identifier = md5_hash[:8]
-        service_name_prefix = deployment_step.service_name_prefix
-        if service_name_prefix is None:
-            # raise an exception if the service name prefix is None
-            raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
-        return (service_name_prefix + identifier).upper()
-
     def _check_if_service_exists(
         self,
         database_name: Optional[sql_identifier.SqlIdentifier],
@@ -959,6 +930,38 @@ class ServiceOperator:
         except exceptions.SnowparkSQLException:
             return False
 
+    @staticmethod
+    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+        """Encode params dictionary to a base64 string.
+
+        Args:
+            params: Optional dictionary of model inference parameters.
+
+        Returns:
+            Base64 encoded JSON string of the params, or None if input is None.
+        """
+        if params is None:
+            return None
+        return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+    @staticmethod
+    def _encode_column_handling(
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+    ) -> Optional[str]:
+        """Validate and encode column_handling to a base64 string.
+
+        Args:
+            column_handling: Optional dictionary mapping column names to file encoding options.
+
+        Returns:
+            Base64 encoded JSON string of the column handling options, or None if input is None.
+        """
+        if column_handling is None:
+            return None
+        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+        validated_input = adapter.validate_python(column_handling)
+        return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+
     def invoke_batch_job_method(
         self,
         *,
@@ -971,6 +974,9 @@ class ServiceOperator:
         image_repo_name: Optional[str],
         input_stage_location: str,
         input_file_pattern: str,
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+        params: Optional[dict[str, Any]],
+        signature_params: Optional[Sequence[core.BaseParamSpec]],
         output_stage_location: str,
         completion_filename: str,
         force_rebuild: bool,
@@ -981,7 +987,13 @@ class ServiceOperator:
         gpu_requests: Optional[str],
         replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> job.MLJob[Any]:
+        # Validate and encode params
+        param_utils.validate_params(params, signature_params)
+        params_encoded = self._encode_params(params)
+        column_handling_encoded = self._encode_column_handling(column_handling)
+
         database_name = self._database_name
         schema_name = self._schema_name
 
@@ -1007,6 +1019,8 @@ class ServiceOperator:
             max_batch_rows=max_batch_rows,
             input_stage_location=input_stage_location,
             input_file_pattern=input_file_pattern,
+            column_handling=column_handling_encoded,
+            params=params_encoded,
             output_stage_location=output_stage_location,
             completion_filename=completion_filename,
             function_name=function_name,
@@ -1017,11 +1031,17 @@ class ServiceOperator:
             replicas=replicas,
         )
 
-
-
-
-
-
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+        else:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=compute_pool_name,
+                fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+                force_rebuild=force_rebuild,
+            )
 
         spec_yaml_str_or_path = self._model_deployment_spec.save()
 
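Before the deployment spec is saved, _encode_params serializes the runtime params to JSON and base64-encodes the result, which is what ends up in the spec's params field; _encode_column_handling does the same through a pydantic TypeAdapter after validating the mapping. A stdlib-only sketch of the params encoding (the parameter names below are hypothetical):

    import base64
    import json

    params = {"temperature": 0.1, "max_tokens": 256}  # hypothetical inference parameters
    encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
    print(encoded)  # the string carried in the deployment spec
    print(json.loads(base64.b64decode(encoded)))  # round-trips back to the original dict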
snowflake/ml/model/_client/service/import_model_spec_schema.py

@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from snowflake.ml.model._client.service import model_deployment_spec_schema
+
+BaseModel.model_config["protected_namespaces"] = ()
+
+
+class ModelName(BaseModel):
+    model_name: str
+    version_name: str
+
+
+class ModelSpec(BaseModel):
+    name: ModelName
+    hf_model: Optional[model_deployment_spec_schema.HuggingFaceModel] = None
+    log_model_args: Optional[model_deployment_spec_schema.LogModelArgs] = None
+
+
+class ImportModelSpec(BaseModel):
+    compute_pool: str
+    models: list[ModelSpec]
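A minimal sketch of building an import spec from these new pydantic models. The module is internal, so the import path and this usage are inferred only from the field definitions above, and the compute pool and model names are hypothetical:

    from snowflake.ml.model._client.service.import_model_spec_schema import (
        ImportModelSpec,
        ModelName,
        ModelSpec,
    )

    spec = ImportModelSpec(
        compute_pool="MY_COMPUTE_POOL",
        models=[ModelSpec(name=ModelName(model_name="MY_MODEL", version_name="V1"))],
    )
    print(spec.model_dump(exclude_none=True))  # pydantic v2 serialization of the spec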
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -195,6 +195,7 @@ class ModelDeploymentSpec:
 
     def add_job_spec(
         self,
+        *,
         job_name: sql_identifier.SqlIdentifier,
         inference_compute_pool_name: sql_identifier.SqlIdentifier,
         function_name: str,
@@ -202,6 +203,8 @@ class ModelDeploymentSpec:
         output_stage_location: str,
         completion_filename: str,
         input_file_pattern: str,
+        column_handling: Optional[str] = None,
+        params: Optional[str] = None,
         warehouse: sql_identifier.SqlIdentifier,
         job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
         job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
@@ -217,14 +220,16 @@ class ModelDeploymentSpec:
         Args:
             job_name: Name of the job.
             inference_compute_pool_name: Compute pool for inference.
-            warehouse: Warehouse for the job.
             function_name: Function name.
             input_stage_location: Stage location for input data.
             output_stage_location: Stage location for output data.
+            completion_filename: Name of completion file (default: "completion.txt").
+            input_file_pattern: Pattern for input files (optional).
+            column_handling: Column handling mode for input data.
+            params: Additional parameters for the job.
+            warehouse: Warehouse for the job.
             job_database_name: Database name for the job.
             job_schema_name: Schema name for the job.
-            input_file_pattern: Pattern for input files (optional).
-            completion_filename: Name of completion file (default: "completion.txt").
             cpu: CPU requirement.
             memory: Memory requirement.
             gpu: GPU requirement.
@@ -259,7 +264,10 @@ class ModelDeploymentSpec:
             warehouse=warehouse.identifier() if warehouse else None,
             function_name=function_name,
             input=model_deployment_spec_schema.Input(
-                input_stage_location=input_stage_location,
+                input_stage_location=input_stage_location,
+                input_file_pattern=input_file_pattern,
+                column_handling=column_handling,
+                params=params,
             ),
             output=model_deployment_spec_schema.Output(
                 output_stage_location=output_stage_location,
@@ -355,7 +363,7 @@ class ModelDeploymentSpec:
         inference_engine: inference_engine_module.InferenceEngine,
         inference_engine_args: Optional[list[str]] = None,
     ) -> "ModelDeploymentSpec":
-        """Add inference engine specification. This must be called after self.add_service_spec().
+        """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().
 
         Args:
             inference_engine: Inference engine.
@@ -368,9 +376,10 @@ class ModelDeploymentSpec:
             ValueError: If inference engine specification is called before add_service_spec().
             ValueError: If the argument does not have a '--' prefix.
         """
-
-
-
+        if self._service is None and self._job is None:
+            raise ValueError(
+                "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+            )
 
         if inference_engine_args is None:
             inference_engine_args = []
@@ -423,11 +432,17 @@ class ModelDeploymentSpec:
 
         inference_engine_args = filtered_args
 
-
+        inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
             # convert to string to be saved in the deployment spec
             inference_engine_name=inference_engine.value,
             inference_engine_args=inference_engine_args,
         )
+
+        if self._service:
+            self._service.inference_engine_spec = inference_engine_spec
+        elif self._job:
+            self._job.inference_engine_spec = inference_engine_spec
+
         return self
 
     def save(self) -> str:
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -39,6 +39,8 @@ class Service(BaseModel):
 class Input(BaseModel):
     input_stage_location: str
     input_file_pattern: str
+    column_handling: Optional[str] = None
+    params: Optional[str] = None
 
 
 class Output(BaseModel):
@@ -59,6 +61,7 @@ class Job(BaseModel):
     input: Input
     output: Output
     replicas: Optional[int] = None
+    inference_engine_spec: Optional[InferenceEngineSpec] = None
 
 
 class LogModelArgs(BaseModel):
@@ -74,6 +77,7 @@ class HuggingFaceModel(BaseModel):
     task: Optional[str] = None
     tokenizer: Optional[str] = None
     token: Optional[str] = None
+    token_secret_object: Optional[str] = None
     trust_remote_code: Optional[bool] = False
     revision: Optional[str] = None
     hf_model_kwargs: Optional[str] = "{}"