snowflake-ml-python 1.14.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +9 -7
- snowflake/ml/_internal/utils/connection_params.py +5 -3
- snowflake/ml/_internal/utils/jwt_generator.py +3 -2
- snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
- snowflake/ml/experiment/_entities/__init__.py +2 -1
- snowflake/ml/experiment/_entities/run.py +0 -15
- snowflake/ml/experiment/_entities/run_metadata.py +3 -51
- snowflake/ml/experiment/experiment_tracking.py +8 -8
- snowflake/ml/model/__init__.py +2 -6
- snowflake/ml/model/_client/model/batch_inference_specs.py +0 -4
- snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
- snowflake/ml/model/_client/model/model_version_impl.py +25 -62
- snowflake/ml/model/_client/ops/service_ops.py +18 -2
- snowflake/ml/model/_client/sql/service.py +29 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
- snowflake/ml/model/_packager/model_packager.py +4 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +0 -1
- snowflake/ml/model/_signatures/utils.py +0 -21
- snowflake/ml/model/models/huggingface_pipeline.py +56 -21
- snowflake/ml/registry/_manager/model_manager.py +1 -1
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/METADATA +42 -35
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/RECORD +30 -29
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import logging
|
|
2
3
|
from contextlib import contextmanager
|
|
3
4
|
from typing import Any, Optional
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from packaging import version
|
|
7
7
|
|
|
8
8
|
from snowflake.ml import version as snowml_version
|
|
@@ -13,6 +13,8 @@ from snowflake.snowpark import (
|
|
|
13
13
|
session as snowpark_session,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
16
18
|
LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
|
|
17
19
|
INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
|
|
18
20
|
|
|
@@ -60,12 +62,12 @@ class PlatformCapabilities:
|
|
|
60
62
|
@classmethod # type: ignore[arg-type]
|
|
61
63
|
@contextmanager
|
|
62
64
|
def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None: # type: ignore[misc]
|
|
63
|
-
|
|
65
|
+
logger.debug(f"Setting mock features: {features}")
|
|
64
66
|
cls.set_mock_features(features)
|
|
65
67
|
try:
|
|
66
68
|
yield
|
|
67
69
|
finally:
|
|
68
|
-
|
|
70
|
+
logger.debug(f"Clearing mock features: {features}")
|
|
69
71
|
cls.clear_mock_features()
|
|
70
72
|
|
|
71
73
|
def is_inlined_deployment_spec_enabled(self) -> bool:
|
|
@@ -98,7 +100,7 @@ class PlatformCapabilities:
|
|
|
98
100
|
error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
|
|
99
101
|
)
|
|
100
102
|
except snowpark_exceptions.SnowparkSQLException as e:
|
|
101
|
-
|
|
103
|
+
logger.debug(f"Failed to retrieve platform capabilities: {e}")
|
|
102
104
|
# This can happen if server side is older than 9.2. That is fine.
|
|
103
105
|
return {}
|
|
104
106
|
|
|
@@ -144,7 +146,7 @@ class PlatformCapabilities:
|
|
|
144
146
|
|
|
145
147
|
value = self.features.get(feature_name)
|
|
146
148
|
if value is None:
|
|
147
|
-
|
|
149
|
+
logger.debug(f"Feature {feature_name} not found, returning large version number")
|
|
148
150
|
return large_version
|
|
149
151
|
|
|
150
152
|
try:
|
|
@@ -152,7 +154,7 @@ class PlatformCapabilities:
|
|
|
152
154
|
version_str = str(value)
|
|
153
155
|
return version.Version(version_str)
|
|
154
156
|
except (version.InvalidVersion, ValueError, TypeError) as e:
|
|
155
|
-
|
|
157
|
+
logger.debug(
|
|
156
158
|
f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
|
|
157
159
|
f"Returning large version number"
|
|
158
160
|
)
|
|
@@ -171,7 +173,7 @@ class PlatformCapabilities:
|
|
|
171
173
|
feature_version = self._get_version_feature(feature_name)
|
|
172
174
|
|
|
173
175
|
result = current_version >= feature_version
|
|
174
|
-
|
|
176
|
+
logger.debug(
|
|
175
177
|
f"Version comparison for feature {feature_name}: "
|
|
176
178
|
f"current={current_version}, feature={feature_version}, enabled={result}"
|
|
177
179
|
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import configparser
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import Optional, Union
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from cryptography.hazmat import backends
|
|
7
7
|
from cryptography.hazmat.primitives import serialization
|
|
8
8
|
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
9
11
|
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
|
|
10
12
|
|
|
11
13
|
|
|
@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
106
108
|
"""Loads the dictionary from snowsql config file."""
|
|
107
109
|
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
|
108
110
|
if not os.path.exists(snowsql_config_file):
|
|
109
|
-
|
|
111
|
+
logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
|
|
110
112
|
raise Exception("Snowflake SnowSQL config not found.")
|
|
111
113
|
|
|
112
114
|
config = configparser.ConfigParser(inline_comment_prefixes="#")
|
|
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
122
124
|
# See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
|
|
123
125
|
connection_name = "connections"
|
|
124
126
|
|
|
125
|
-
|
|
127
|
+
logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
|
|
126
128
|
config.read(snowsql_config_file)
|
|
127
129
|
conn_params = dict(config[connection_name])
|
|
128
130
|
# Remap names to appropriate args in Python Connector API
|
|
@@ -110,15 +110,16 @@ class JWTGenerator:
|
|
|
110
110
|
}
|
|
111
111
|
|
|
112
112
|
# Regenerate the actual token
|
|
113
|
-
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
|
|
113
|
+
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) # type: ignore[arg-type]
|
|
114
114
|
# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
|
|
115
115
|
# If the token is a byte string, convert it to a string.
|
|
116
116
|
if isinstance(token, bytes):
|
|
117
117
|
token = token.decode("utf-8")
|
|
118
118
|
self.token = token
|
|
119
|
+
public_key = self.private_key.public_key()
|
|
119
120
|
logger.info(
|
|
120
121
|
"Generated a JWT with the following payload: %s",
|
|
121
|
-
jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
|
|
122
|
+
jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]), # type: ignore[arg-type]
|
|
122
123
|
)
|
|
123
124
|
|
|
124
125
|
return token
|
|
@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
|
|
|
76
76
|
self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
|
|
77
77
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
78
78
|
|
|
79
|
-
def
|
|
79
|
+
def modify_run_add_metrics(
|
|
80
80
|
self,
|
|
81
81
|
*,
|
|
82
82
|
experiment_name: sql_identifier.SqlIdentifier,
|
|
83
83
|
run_name: sql_identifier.SqlIdentifier,
|
|
84
|
-
|
|
84
|
+
metrics: str,
|
|
85
85
|
) -> None:
|
|
86
86
|
experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
|
|
87
87
|
query_result_checker.SqlResultValidator(
|
|
88
88
|
self._session,
|
|
89
|
-
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name}
|
|
89
|
+
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
|
|
90
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
|
+
|
|
92
|
+
def modify_run_add_params(
|
|
93
|
+
self,
|
|
94
|
+
*,
|
|
95
|
+
experiment_name: sql_identifier.SqlIdentifier,
|
|
96
|
+
run_name: sql_identifier.SqlIdentifier,
|
|
97
|
+
params: str,
|
|
98
|
+
) -> None:
|
|
99
|
+
experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
|
|
100
|
+
query_result_checker.SqlResultValidator(
|
|
101
|
+
self._session,
|
|
102
|
+
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
|
|
90
103
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
104
|
|
|
92
105
|
def put_artifact(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from snowflake.ml.experiment._entities.experiment import Experiment
|
|
2
2
|
from snowflake.ml.experiment._entities.run import Run
|
|
3
|
+
from snowflake.ml.experiment._entities.run_metadata import Metric, Param
|
|
3
4
|
|
|
4
|
-
__all__ = ["Experiment", "Run"]
|
|
5
|
+
__all__ = ["Experiment", "Run", "Metric", "Param"]
|
|
@@ -1,11 +1,8 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import types
|
|
3
2
|
from typing import TYPE_CHECKING, Optional
|
|
4
3
|
|
|
5
4
|
from snowflake.ml._internal.utils import sql_identifier
|
|
6
5
|
from snowflake.ml.experiment import _experiment_info as experiment_info
|
|
7
|
-
from snowflake.ml.experiment._client import experiment_tracking_sql_client
|
|
8
|
-
from snowflake.ml.experiment._entities import run_metadata
|
|
9
6
|
|
|
10
7
|
if TYPE_CHECKING:
|
|
11
8
|
from snowflake.ml.experiment import experiment_tracking
|
|
@@ -41,18 +38,6 @@ class Run:
|
|
|
41
38
|
if self._experiment_tracking._run is self:
|
|
42
39
|
self._experiment_tracking.end_run()
|
|
43
40
|
|
|
44
|
-
def _get_metadata(
|
|
45
|
-
self,
|
|
46
|
-
) -> run_metadata.RunMetadata:
|
|
47
|
-
runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
|
|
48
|
-
experiment_name=self.experiment_name, like=str(self.name)
|
|
49
|
-
)
|
|
50
|
-
if not runs:
|
|
51
|
-
raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
|
|
52
|
-
return run_metadata.RunMetadata.from_dict(
|
|
53
|
-
json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
|
|
54
|
-
)
|
|
55
|
-
|
|
56
41
|
def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
|
|
57
42
|
return experiment_info.ExperimentInfo(
|
|
58
43
|
fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
|
|
@@ -1,12 +1,4 @@
|
|
|
1
1
|
import dataclasses
|
|
2
|
-
import enum
|
|
3
|
-
import typing
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class RunStatus(str, enum.Enum):
|
|
7
|
-
UNKNOWN = "UNKNOWN"
|
|
8
|
-
RUNNING = "RUNNING"
|
|
9
|
-
FINISHED = "FINISHED"
|
|
10
2
|
|
|
11
3
|
|
|
12
4
|
@dataclasses.dataclass
|
|
@@ -15,54 +7,14 @@ class Metric:
|
|
|
15
7
|
value: float
|
|
16
8
|
step: int
|
|
17
9
|
|
|
10
|
+
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
11
|
+
return dataclasses.asdict(self)
|
|
12
|
+
|
|
18
13
|
|
|
19
14
|
@dataclasses.dataclass
|
|
20
15
|
class Param:
|
|
21
16
|
name: str
|
|
22
17
|
value: str
|
|
23
18
|
|
|
24
|
-
|
|
25
|
-
@dataclasses.dataclass
|
|
26
|
-
class RunMetadata:
|
|
27
|
-
status: RunStatus
|
|
28
|
-
metrics: list[Metric]
|
|
29
|
-
parameters: list[Param]
|
|
30
|
-
|
|
31
|
-
@classmethod
|
|
32
|
-
def from_dict(
|
|
33
|
-
cls,
|
|
34
|
-
metadata: dict, # type: ignore[type-arg]
|
|
35
|
-
) -> "RunMetadata":
|
|
36
|
-
return RunMetadata(
|
|
37
|
-
status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
|
|
38
|
-
metrics=[Metric(**m) for m in metadata.get("metrics", [])],
|
|
39
|
-
parameters=[Param(**p) for p in metadata.get("parameters", [])],
|
|
40
|
-
)
|
|
41
|
-
|
|
42
19
|
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
43
20
|
return dataclasses.asdict(self)
|
|
44
|
-
|
|
45
|
-
def set_metric(
|
|
46
|
-
self,
|
|
47
|
-
key: str,
|
|
48
|
-
value: float,
|
|
49
|
-
step: int,
|
|
50
|
-
) -> None:
|
|
51
|
-
for metric in self.metrics:
|
|
52
|
-
if metric.name == key and metric.step == step:
|
|
53
|
-
metric.value = value
|
|
54
|
-
break
|
|
55
|
-
else:
|
|
56
|
-
self.metrics.append(Metric(name=key, value=value, step=step))
|
|
57
|
-
|
|
58
|
-
def set_param(
|
|
59
|
-
self,
|
|
60
|
-
key: str,
|
|
61
|
-
value: typing.Any,
|
|
62
|
-
) -> None:
|
|
63
|
-
for parameter in self.parameters:
|
|
64
|
-
if parameter.name == key:
|
|
65
|
-
parameter.value = str(value)
|
|
66
|
-
break
|
|
67
|
-
else:
|
|
68
|
-
self.parameters.append(Param(name=key, value=str(value)))
|
|
@@ -261,13 +261,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
261
261
|
step: The step of the metrics. Defaults to 0.
|
|
262
262
|
"""
|
|
263
263
|
run = self._get_or_start_run()
|
|
264
|
-
|
|
264
|
+
metrics_list = []
|
|
265
265
|
for key, value in metrics.items():
|
|
266
|
-
|
|
267
|
-
self._sql_client.
|
|
266
|
+
metrics_list.append(entities.Metric(key, value, step))
|
|
267
|
+
self._sql_client.modify_run_add_metrics(
|
|
268
268
|
experiment_name=run.experiment_name,
|
|
269
269
|
run_name=run.name,
|
|
270
|
-
|
|
270
|
+
metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
|
|
271
271
|
)
|
|
272
272
|
|
|
273
273
|
def log_param(
|
|
@@ -296,13 +296,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
296
296
|
to string.
|
|
297
297
|
"""
|
|
298
298
|
run = self._get_or_start_run()
|
|
299
|
-
|
|
299
|
+
params_list = []
|
|
300
300
|
for key, value in params.items():
|
|
301
|
-
|
|
302
|
-
self._sql_client.
|
|
301
|
+
params_list.append(entities.Param(key, str(value)))
|
|
302
|
+
self._sql_client.modify_run_add_params(
|
|
303
303
|
experiment_name=run.experiment_name,
|
|
304
304
|
run_name=run.name,
|
|
305
|
-
|
|
305
|
+
params=json.dumps([param.to_dict() for param in params_list]),
|
|
306
306
|
)
|
|
307
307
|
|
|
308
308
|
def log_artifact(
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
|
1
|
-
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
2
|
-
InputSpec,
|
|
3
|
-
JobSpec,
|
|
4
|
-
OutputSpec,
|
|
5
|
-
)
|
|
1
|
+
from snowflake.ml.model._client.model.batch_inference_specs import JobSpec, OutputSpec
|
|
6
2
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
7
3
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
8
4
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
9
5
|
|
|
10
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
|
|
6
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "JobSpec", "OutputSpec"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
from snowflake.ml.model._client.ops import service_ops
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _get_inference_engine_args(
|
|
7
|
+
experimental_options: Optional[dict[str, Any]],
|
|
8
|
+
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
9
|
+
|
|
10
|
+
if not experimental_options:
|
|
11
|
+
return None
|
|
12
|
+
|
|
13
|
+
if "inference_engine" not in experimental_options:
|
|
14
|
+
raise ValueError("inference_engine is required in experimental_options")
|
|
15
|
+
|
|
16
|
+
return service_ops.InferenceEngineArgs(
|
|
17
|
+
inference_engine=experimental_options["inference_engine"],
|
|
18
|
+
inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _enrich_inference_engine_args(
|
|
23
|
+
inference_engine_args: service_ops.InferenceEngineArgs,
|
|
24
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
|
25
|
+
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
26
|
+
"""Enrich inference engine args with model path and tensor parallelism settings.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
inference_engine_args: The original inference engine args
|
|
30
|
+
gpu_requests: The number of GPUs requested
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Enriched inference engine args
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
ValueError: Invalid gpu_requests
|
|
37
|
+
"""
|
|
38
|
+
if inference_engine_args.inference_engine_args_override is None:
|
|
39
|
+
inference_engine_args.inference_engine_args_override = []
|
|
40
|
+
|
|
41
|
+
gpu_count = None
|
|
42
|
+
|
|
43
|
+
# Set tensor-parallelism if gpu_requests is specified
|
|
44
|
+
if gpu_requests is not None:
|
|
45
|
+
# assert gpu_requests is a string or an integer before casting to int
|
|
46
|
+
try:
|
|
47
|
+
gpu_count = int(gpu_requests)
|
|
48
|
+
if gpu_count > 0:
|
|
49
|
+
inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
|
|
50
|
+
else:
|
|
51
|
+
raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
|
|
52
|
+
except ValueError:
|
|
53
|
+
raise ValueError(f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}")
|
|
54
|
+
|
|
55
|
+
return inference_engine_args
|
|
@@ -12,7 +12,10 @@ from snowflake.ml._internal import telemetry
|
|
|
12
12
|
from snowflake.ml._internal.utils import sql_identifier
|
|
13
13
|
from snowflake.ml.lineage import lineage_node
|
|
14
14
|
from snowflake.ml.model import task, type_hints
|
|
15
|
-
from snowflake.ml.model._client.model import batch_inference_specs
|
|
15
|
+
from snowflake.ml.model._client.model import (
|
|
16
|
+
batch_inference_specs,
|
|
17
|
+
inference_engine_utils,
|
|
18
|
+
)
|
|
16
19
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
17
20
|
from snowflake.ml.model._model_composer import model_composer
|
|
18
21
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -22,6 +25,7 @@ from snowflake.snowpark import Session, async_job, dataframe
|
|
|
22
25
|
_TELEMETRY_PROJECT = "MLOps"
|
|
23
26
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
|
24
27
|
_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
|
|
28
|
+
_BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
class ExportMode(enum.Enum):
|
|
@@ -553,7 +557,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
553
557
|
self,
|
|
554
558
|
*,
|
|
555
559
|
compute_pool: str,
|
|
556
|
-
input_spec: batch_inference_specs.InputSpec,
|
|
560
|
+
input_spec: dataframe.DataFrame,
|
|
557
561
|
output_spec: batch_inference_specs.OutputSpec,
|
|
558
562
|
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
559
563
|
) -> jobs.MLJob[Any]:
|
|
@@ -569,6 +573,18 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
569
573
|
if warehouse is None:
|
|
570
574
|
raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
|
|
571
575
|
|
|
576
|
+
# use a temporary folder in the output stage to store the intermediate output from the dataframe
|
|
577
|
+
output_stage_location = output_spec.stage_location
|
|
578
|
+
if not output_stage_location.endswith("/"):
|
|
579
|
+
output_stage_location += "/"
|
|
580
|
+
input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
|
|
581
|
+
|
|
582
|
+
try:
|
|
583
|
+
input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
|
|
584
|
+
# todo: be specific about the type of errors to provide better error messages.
|
|
585
|
+
except Exception as e:
|
|
586
|
+
raise RuntimeError(f"Failed to process input_spec: {e}")
|
|
587
|
+
|
|
572
588
|
if job_spec.job_name is None:
|
|
573
589
|
# Same as the MLJob ID generation logic with a different prefix
|
|
574
590
|
job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
|
|
@@ -592,9 +608,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
592
608
|
job_name=job_name,
|
|
593
609
|
replicas=job_spec.replicas,
|
|
594
610
|
# input and output
|
|
595
|
-
input_stage_location=input_spec.stage_location,
|
|
611
|
+
input_stage_location=input_stage_location,
|
|
596
612
|
input_file_pattern="*",
|
|
597
|
-
output_stage_location=output_spec.stage_location,
|
|
613
|
+
output_stage_location=output_stage_location,
|
|
598
614
|
completion_filename="_SUCCESS",
|
|
599
615
|
# misc
|
|
600
616
|
statement_params=statement_params,
|
|
@@ -768,60 +784,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
768
784
|
version_name=sql_identifier.SqlIdentifier(version),
|
|
769
785
|
)
|
|
770
786
|
|
|
771
|
-
def _get_inference_engine_args(
|
|
772
|
-
self, experimental_options: Optional[dict[str, Any]]
|
|
773
|
-
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
774
|
-
|
|
775
|
-
if not experimental_options:
|
|
776
|
-
return None
|
|
777
|
-
|
|
778
|
-
if "inference_engine" not in experimental_options:
|
|
779
|
-
raise ValueError("inference_engine is required in experimental_options")
|
|
780
|
-
|
|
781
|
-
return service_ops.InferenceEngineArgs(
|
|
782
|
-
inference_engine=experimental_options["inference_engine"],
|
|
783
|
-
inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
|
|
784
|
-
)
|
|
785
|
-
|
|
786
|
-
def _enrich_inference_engine_args(
|
|
787
|
-
self,
|
|
788
|
-
inference_engine_args: service_ops.InferenceEngineArgs,
|
|
789
|
-
gpu_requests: Optional[Union[str, int]] = None,
|
|
790
|
-
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
791
|
-
"""Enrich inference engine args with tensor parallelism settings.
|
|
792
|
-
|
|
793
|
-
Args:
|
|
794
|
-
inference_engine_args: The original inference engine args
|
|
795
|
-
gpu_requests: The number of GPUs requested
|
|
796
|
-
|
|
797
|
-
Returns:
|
|
798
|
-
Enriched inference engine args
|
|
799
|
-
|
|
800
|
-
Raises:
|
|
801
|
-
ValueError: Invalid gpu_requests
|
|
802
|
-
"""
|
|
803
|
-
if inference_engine_args.inference_engine_args_override is None:
|
|
804
|
-
inference_engine_args.inference_engine_args_override = []
|
|
805
|
-
|
|
806
|
-
gpu_count = None
|
|
807
|
-
|
|
808
|
-
# Set tensor-parallelism if gpu_requests is specified
|
|
809
|
-
if gpu_requests is not None:
|
|
810
|
-
# assert gpu_requests is a string or an integer before casting to int
|
|
811
|
-
if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
|
|
812
|
-
try:
|
|
813
|
-
gpu_count = int(gpu_requests)
|
|
814
|
-
except ValueError:
|
|
815
|
-
raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
|
|
816
|
-
|
|
817
|
-
if gpu_count is not None:
|
|
818
|
-
if gpu_count > 0:
|
|
819
|
-
inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
|
|
820
|
-
else:
|
|
821
|
-
raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
|
|
822
|
-
|
|
823
|
-
return inference_engine_args
|
|
824
|
-
|
|
825
787
|
def _check_huggingface_text_generation_model(
|
|
826
788
|
self,
|
|
827
789
|
statement_params: Optional[dict[str, Any]] = None,
|
|
@@ -1101,13 +1063,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1101
1063
|
if experimental_options:
|
|
1102
1064
|
self._check_huggingface_text_generation_model(statement_params)
|
|
1103
1065
|
|
|
1104
|
-
inference_engine_args
|
|
1105
|
-
experimental_options
|
|
1106
|
-
)
|
|
1066
|
+
inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
|
|
1107
1067
|
|
|
1108
1068
|
# Enrich inference engine args if inference engine is specified
|
|
1109
1069
|
if inference_engine_args is not None:
|
|
1110
|
-
inference_engine_args =
|
|
1070
|
+
inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
|
|
1071
|
+
inference_engine_args,
|
|
1072
|
+
gpu_requests,
|
|
1073
|
+
)
|
|
1111
1074
|
|
|
1112
1075
|
from snowflake.ml.model import event_handler
|
|
1113
1076
|
from snowflake.snowpark import exceptions
|
|
@@ -155,7 +155,8 @@ class ServiceOperator:
|
|
|
155
155
|
database_name=database_name,
|
|
156
156
|
schema_name=schema_name,
|
|
157
157
|
)
|
|
158
|
-
|
|
158
|
+
self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
|
|
159
|
+
if self._use_inlined_deployment_spec:
|
|
159
160
|
self._workspace = None
|
|
160
161
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
|
|
161
162
|
else:
|
|
@@ -264,7 +265,14 @@ class ServiceOperator:
|
|
|
264
265
|
self._model_deployment_spec.add_hf_logger_spec(
|
|
265
266
|
hf_model_name=hf_model_args.hf_model_name,
|
|
266
267
|
hf_task=hf_model_args.hf_task,
|
|
267
|
-
hf_token=
|
|
268
|
+
hf_token=(
|
|
269
|
+
# when using inlined deployment spec, we need to use QMARK_RESERVED_TOKEN
|
|
270
|
+
# to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
|
|
271
|
+
# noop if using file-based deployment spec or token is not provided
|
|
272
|
+
service_sql.QMARK_RESERVED_TOKEN
|
|
273
|
+
if hf_model_args.hf_token and self._use_inlined_deployment_spec
|
|
274
|
+
else hf_model_args.hf_token
|
|
275
|
+
),
|
|
268
276
|
hf_tokenizer=hf_model_args.hf_tokenizer,
|
|
269
277
|
hf_revision=hf_model_args.hf_revision,
|
|
270
278
|
hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
|
|
@@ -320,6 +328,14 @@ class ServiceOperator:
|
|
|
320
328
|
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
|
321
329
|
),
|
|
322
330
|
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
|
331
|
+
query_params=(
|
|
332
|
+
# when using inlined deployment spec, we need to add the token to the query params
|
|
333
|
+
# to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
|
|
334
|
+
# noop if using file-based deployment spec or token is not provided
|
|
335
|
+
[hf_model_args.hf_token]
|
|
336
|
+
if (self._use_inlined_deployment_spec and hf_model_args and hf_model_args.hf_token)
|
|
337
|
+
else []
|
|
338
|
+
),
|
|
323
339
|
statement_params=statement_params,
|
|
324
340
|
)
|
|
325
341
|
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import contextlib
|
|
1
2
|
import dataclasses
|
|
2
3
|
import enum
|
|
3
4
|
import logging
|
|
4
5
|
import textwrap
|
|
5
|
-
from typing import Any, Optional
|
|
6
|
+
from typing import Any, Generator, Optional
|
|
6
7
|
|
|
7
8
|
from snowflake import snowpark
|
|
8
9
|
from snowflake.ml._internal.utils import (
|
|
@@ -17,6 +18,11 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
|
17
18
|
|
|
18
19
|
logger = logging.getLogger(__name__)
|
|
19
20
|
|
|
21
|
+
# Using this token instead of '?' to avoid escaping issues
|
|
22
|
+
# After quotes are escaped, we replace this token with '|| ? ||'
|
|
23
|
+
QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
|
|
24
|
+
QMARK_PARAMETER_TOKEN = "'|| ? ||'"
|
|
25
|
+
|
|
20
26
|
|
|
21
27
|
class ServiceStatus(enum.Enum):
|
|
22
28
|
PENDING = "PENDING"
|
|
@@ -70,12 +76,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
70
76
|
CONTAINER_STATUS = "status"
|
|
71
77
|
MESSAGE = "message"
|
|
72
78
|
|
|
79
|
+
@contextlib.contextmanager
|
|
80
|
+
def _qmark_paramstyle(self) -> Generator[None, None, None]:
|
|
81
|
+
"""Context manager that temporarily changes paramstyle to qmark and restores original value on exit."""
|
|
82
|
+
if not hasattr(self._session, "_options"):
|
|
83
|
+
yield
|
|
84
|
+
else:
|
|
85
|
+
original_paramstyle = self._session._options["paramstyle"]
|
|
86
|
+
try:
|
|
87
|
+
self._session._options["paramstyle"] = "qmark"
|
|
88
|
+
yield
|
|
89
|
+
finally:
|
|
90
|
+
self._session._options["paramstyle"] = original_paramstyle
|
|
91
|
+
|
|
73
92
|
def deploy_model(
|
|
74
93
|
self,
|
|
75
94
|
*,
|
|
76
95
|
stage_path: Optional[str] = None,
|
|
77
96
|
model_deployment_spec_yaml_str: Optional[str] = None,
|
|
78
97
|
model_deployment_spec_file_rel_path: Optional[str] = None,
|
|
98
|
+
query_params: Optional[list[Any]] = None,
|
|
79
99
|
statement_params: Optional[dict[str, Any]] = None,
|
|
80
100
|
) -> tuple[str, snowpark.AsyncJob]:
|
|
81
101
|
assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
|
|
@@ -83,11 +103,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
83
103
|
model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
|
|
84
104
|
model_deployment_spec_yaml_str
|
|
85
105
|
) # type: ignore[no-untyped-call]
|
|
106
|
+
model_deployment_spec_yaml_str = model_deployment_spec_yaml_str.replace( # type: ignore[union-attr]
|
|
107
|
+
QMARK_RESERVED_TOKEN, QMARK_PARAMETER_TOKEN
|
|
108
|
+
)
|
|
86
109
|
logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
|
|
87
110
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
|
|
88
111
|
else:
|
|
89
112
|
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
|
90
|
-
|
|
113
|
+
with self._qmark_paramstyle():
|
|
114
|
+
async_job = self._session.sql(
|
|
115
|
+
sql_str,
|
|
116
|
+
params=query_params if query_params else None,
|
|
117
|
+
).collect(block=False, statement_params=statement_params)
|
|
91
118
|
assert isinstance(async_job, snowpark.AsyncJob)
|
|
92
119
|
return async_job.query_id, async_job
|
|
93
120
|
|