snowflake-ml-python 1.13.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +9 -7
- snowflake/ml/_internal/utils/connection_params.py +5 -3
- snowflake/ml/_internal/utils/jwt_generator.py +3 -2
- snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
- snowflake/ml/experiment/_entities/__init__.py +2 -1
- snowflake/ml/experiment/_entities/run.py +0 -15
- snowflake/ml/experiment/_entities/run_metadata.py +3 -51
- snowflake/ml/experiment/experiment_tracking.py +8 -8
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +9 -7
- snowflake/ml/jobs/job.py +12 -4
- snowflake/ml/jobs/manager.py +34 -7
- snowflake/ml/lineage/lineage_node.py +0 -1
- snowflake/ml/model/__init__.py +2 -6
- snowflake/ml/model/_client/model/batch_inference_specs.py +0 -4
- snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
- snowflake/ml/model/_client/model/model_version_impl.py +25 -77
- snowflake/ml/model/_client/ops/model_ops.py +9 -2
- snowflake/ml/model/_client/ops/service_ops.py +82 -36
- snowflake/ml/model/_client/sql/service.py +29 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
- snowflake/ml/model/_packager/model_packager.py +4 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +0 -1
- snowflake/ml/model/_signatures/utils.py +0 -21
- snowflake/ml/model/models/huggingface_pipeline.py +56 -21
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +47 -3
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +30 -0
- snowflake/ml/registry/_manager/model_manager.py +1 -1
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/METADATA +51 -34
- {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/RECORD +40 -39
- {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import logging
|
|
2
3
|
from contextlib import contextmanager
|
|
3
4
|
from typing import Any, Optional
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from packaging import version
|
|
7
7
|
|
|
8
8
|
from snowflake.ml import version as snowml_version
|
|
@@ -13,6 +13,8 @@ from snowflake.snowpark import (
|
|
|
13
13
|
session as snowpark_session,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
16
18
|
LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
|
|
17
19
|
INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
|
|
18
20
|
|
|
@@ -60,12 +62,12 @@ class PlatformCapabilities:
|
|
|
60
62
|
@classmethod  # type: ignore[arg-type]
@contextmanager
def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
    """Temporarily install mock platform features for the duration of a ``with`` block.

    On entry the given mapping is installed via ``cls.set_mock_features``; on exit (normal or
    exceptional) ``cls.clear_mock_features`` is called.

    Args:
        features: Mapping of feature name to value to install. Defaults to ``_dummy_features``.

    Yields:
        Nothing; the mock is active only while the ``with`` body runs.
    """
    logger.debug(f"Setting mock features: {features}")
    cls.set_mock_features(features)
    try:
        yield
    finally:
        # Runs even if the managed block raised, so mocks never leak into later code.
        logger.debug(f"Clearing mock features: {features}")
        cls.clear_mock_features()
|
|
70
72
|
|
|
71
73
|
def is_inlined_deployment_spec_enabled(self) -> bool:
|
|
@@ -98,7 +100,7 @@ class PlatformCapabilities:
|
|
|
98
100
|
error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
|
|
99
101
|
)
|
|
100
102
|
except snowpark_exceptions.SnowparkSQLException as e:
|
|
101
|
-
|
|
103
|
+
logger.debug(f"Failed to retrieve platform capabilities: {e}")
|
|
102
104
|
# This can happen is server side is older than 9.2. That is fine.
|
|
103
105
|
return {}
|
|
104
106
|
|
|
@@ -144,7 +146,7 @@ class PlatformCapabilities:
|
|
|
144
146
|
|
|
145
147
|
value = self.features.get(feature_name)
|
|
146
148
|
if value is None:
|
|
147
|
-
|
|
149
|
+
logger.debug(f"Feature {feature_name} not found, returning large version number")
|
|
148
150
|
return large_version
|
|
149
151
|
|
|
150
152
|
try:
|
|
@@ -152,7 +154,7 @@ class PlatformCapabilities:
|
|
|
152
154
|
version_str = str(value)
|
|
153
155
|
return version.Version(version_str)
|
|
154
156
|
except (version.InvalidVersion, ValueError, TypeError) as e:
|
|
155
|
-
|
|
157
|
+
logger.debug(
|
|
156
158
|
f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
|
|
157
159
|
f"Returning large version number"
|
|
158
160
|
)
|
|
@@ -171,7 +173,7 @@ class PlatformCapabilities:
|
|
|
171
173
|
feature_version = self._get_version_feature(feature_name)
|
|
172
174
|
|
|
173
175
|
result = current_version >= feature_version
|
|
174
|
-
|
|
176
|
+
logger.debug(
|
|
175
177
|
f"Version comparison for feature {feature_name}: "
|
|
176
178
|
f"current={current_version}, feature={feature_version}, enabled={result}"
|
|
177
179
|
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import configparser
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import Optional, Union
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from cryptography.hazmat import backends
|
|
7
7
|
from cryptography.hazmat.primitives import serialization
|
|
8
8
|
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
9
11
|
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
|
|
10
12
|
|
|
11
13
|
|
|
@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
106
108
|
"""Loads the dictionary from snowsql config file."""
|
|
107
109
|
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
|
108
110
|
if not os.path.exists(snowsql_config_file):
|
|
109
|
-
|
|
111
|
+
logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
|
|
110
112
|
raise Exception("Snowflake SnowSQL config not found.")
|
|
111
113
|
|
|
112
114
|
config = configparser.ConfigParser(inline_comment_prefixes="#")
|
|
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
122
124
|
# See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
|
|
123
125
|
connection_name = "connections"
|
|
124
126
|
|
|
125
|
-
|
|
127
|
+
logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
|
|
126
128
|
config.read(snowsql_config_file)
|
|
127
129
|
conn_params = dict(config[connection_name])
|
|
128
130
|
# Remap names to appropriate args in Python Connector API
|
|
@@ -110,15 +110,16 @@ class JWTGenerator:
|
|
|
110
110
|
}
|
|
111
111
|
|
|
112
112
|
# Regenerate the actual token
|
|
113
|
-
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
|
|
113
|
+
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) # type: ignore[arg-type]
|
|
114
114
|
# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
|
|
115
115
|
# If the token is a byte string, convert it to a string.
|
|
116
116
|
if isinstance(token, bytes):
|
|
117
117
|
token = token.decode("utf-8")
|
|
118
118
|
self.token = token
|
|
119
|
+
public_key = self.private_key.public_key()
|
|
119
120
|
logger.info(
|
|
120
121
|
"Generated a JWT with the following payload: %s",
|
|
121
|
-
jwt.decode(self.token, key=
|
|
122
|
+
jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]), # type: ignore[arg-type]
|
|
122
123
|
)
|
|
123
124
|
|
|
124
125
|
return token
|
|
@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
|
|
|
76
76
|
self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
|
|
77
77
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
78
78
|
|
|
79
|
-
def
|
|
79
|
+
def modify_run_add_metrics(
|
|
80
80
|
self,
|
|
81
81
|
*,
|
|
82
82
|
experiment_name: sql_identifier.SqlIdentifier,
|
|
83
83
|
run_name: sql_identifier.SqlIdentifier,
|
|
84
|
-
|
|
84
|
+
metrics: str,
|
|
85
85
|
) -> None:
|
|
86
86
|
experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
|
|
87
87
|
query_result_checker.SqlResultValidator(
|
|
88
88
|
self._session,
|
|
89
|
-
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name}
|
|
89
|
+
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
|
|
90
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
|
+
|
|
92
|
+
def modify_run_add_params(
    self,
    *,
    experiment_name: sql_identifier.SqlIdentifier,
    run_name: sql_identifier.SqlIdentifier,
    params: str,
) -> None:
    """Attach parameters to a run via ``ALTER EXPERIMENT ... MODIFY RUN ... ADD PARAMETERS``.

    Args:
        experiment_name: Name of the experiment, qualified with this client's database and schema.
        run_name: Name of the run to modify.
        params: Serialized parameter payload. The visible caller in ExperimentTracking passes a
            JSON-encoded list of parameter dicts.

    Raises:
        ValueError: If ``params`` contains ``$$`` or ends with ``$``. Either would terminate the
            ``$$``-quoted literal the payload is embedded in.
    """
    # The payload is embedded in a dollar-quoted string constant, which has no escape mechanism.
    # An embedded "$$", or a trailing "$" adjacent to the closing delimiter, would end the literal
    # early and corrupt (or inject into) the statement. Fail with a clear error instead.
    if "$$" in params or params.endswith("$"):
        raise ValueError(
            "params must not contain '$$' or end with '$' because it is embedded in a $$-quoted SQL string."
        )
    experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
    query_result_checker.SqlResultValidator(
        self._session,
        f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
    ).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
104
|
|
|
92
105
|
def put_artifact(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from snowflake.ml.experiment._entities.experiment import Experiment
|
|
2
2
|
from snowflake.ml.experiment._entities.run import Run
|
|
3
|
+
from snowflake.ml.experiment._entities.run_metadata import Metric, Param
|
|
3
4
|
|
|
4
|
-
__all__ = ["Experiment", "Run"]
|
|
5
|
+
__all__ = ["Experiment", "Run", "Metric", "Param"]
|
|
@@ -1,11 +1,8 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import types
|
|
3
2
|
from typing import TYPE_CHECKING, Optional
|
|
4
3
|
|
|
5
4
|
from snowflake.ml._internal.utils import sql_identifier
|
|
6
5
|
from snowflake.ml.experiment import _experiment_info as experiment_info
|
|
7
|
-
from snowflake.ml.experiment._client import experiment_tracking_sql_client
|
|
8
|
-
from snowflake.ml.experiment._entities import run_metadata
|
|
9
6
|
|
|
10
7
|
if TYPE_CHECKING:
|
|
11
8
|
from snowflake.ml.experiment import experiment_tracking
|
|
@@ -41,18 +38,6 @@ class Run:
|
|
|
41
38
|
if self._experiment_tracking._run is self:
|
|
42
39
|
self._experiment_tracking.end_run()
|
|
43
40
|
|
|
44
|
-
def _get_metadata(
|
|
45
|
-
self,
|
|
46
|
-
) -> run_metadata.RunMetadata:
|
|
47
|
-
runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
|
|
48
|
-
experiment_name=self.experiment_name, like=str(self.name)
|
|
49
|
-
)
|
|
50
|
-
if not runs:
|
|
51
|
-
raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
|
|
52
|
-
return run_metadata.RunMetadata.from_dict(
|
|
53
|
-
json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
|
|
54
|
-
)
|
|
55
|
-
|
|
56
41
|
def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
|
|
57
42
|
return experiment_info.ExperimentInfo(
|
|
58
43
|
fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
|
|
@@ -1,12 +1,4 @@
|
|
|
1
1
|
import dataclasses
|
|
2
|
-
import enum
|
|
3
|
-
import typing
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class RunStatus(str, enum.Enum):
|
|
7
|
-
UNKNOWN = "UNKNOWN"
|
|
8
|
-
RUNNING = "RUNNING"
|
|
9
|
-
FINISHED = "FINISHED"
|
|
10
2
|
|
|
11
3
|
|
|
12
4
|
@dataclasses.dataclass
|
|
@@ -15,54 +7,14 @@ class Metric:
|
|
|
15
7
|
value: float
|
|
16
8
|
step: int
|
|
17
9
|
|
|
10
|
+
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
11
|
+
return dataclasses.asdict(self)
|
|
12
|
+
|
|
18
13
|
|
|
19
14
|
@dataclasses.dataclass
|
|
20
15
|
class Param:
|
|
21
16
|
name: str
|
|
22
17
|
value: str
|
|
23
18
|
|
|
24
|
-
|
|
25
|
-
@dataclasses.dataclass
|
|
26
|
-
class RunMetadata:
|
|
27
|
-
status: RunStatus
|
|
28
|
-
metrics: list[Metric]
|
|
29
|
-
parameters: list[Param]
|
|
30
|
-
|
|
31
|
-
@classmethod
|
|
32
|
-
def from_dict(
|
|
33
|
-
cls,
|
|
34
|
-
metadata: dict, # type: ignore[type-arg]
|
|
35
|
-
) -> "RunMetadata":
|
|
36
|
-
return RunMetadata(
|
|
37
|
-
status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
|
|
38
|
-
metrics=[Metric(**m) for m in metadata.get("metrics", [])],
|
|
39
|
-
parameters=[Param(**p) for p in metadata.get("parameters", [])],
|
|
40
|
-
)
|
|
41
|
-
|
|
42
19
|
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
43
20
|
return dataclasses.asdict(self)
|
|
44
|
-
|
|
45
|
-
def set_metric(
|
|
46
|
-
self,
|
|
47
|
-
key: str,
|
|
48
|
-
value: float,
|
|
49
|
-
step: int,
|
|
50
|
-
) -> None:
|
|
51
|
-
for metric in self.metrics:
|
|
52
|
-
if metric.name == key and metric.step == step:
|
|
53
|
-
metric.value = value
|
|
54
|
-
break
|
|
55
|
-
else:
|
|
56
|
-
self.metrics.append(Metric(name=key, value=value, step=step))
|
|
57
|
-
|
|
58
|
-
def set_param(
|
|
59
|
-
self,
|
|
60
|
-
key: str,
|
|
61
|
-
value: typing.Any,
|
|
62
|
-
) -> None:
|
|
63
|
-
for parameter in self.parameters:
|
|
64
|
-
if parameter.name == key:
|
|
65
|
-
parameter.value = str(value)
|
|
66
|
-
break
|
|
67
|
-
else:
|
|
68
|
-
self.parameters.append(Param(name=key, value=str(value)))
|
|
@@ -261,13 +261,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
261
261
|
step: The step of the metrics. Defaults to 0.
|
|
262
262
|
"""
|
|
263
263
|
run = self._get_or_start_run()
|
|
264
|
-
|
|
264
|
+
metrics_list = []
|
|
265
265
|
for key, value in metrics.items():
|
|
266
|
-
|
|
267
|
-
self._sql_client.
|
|
266
|
+
metrics_list.append(entities.Metric(key, value, step))
|
|
267
|
+
self._sql_client.modify_run_add_metrics(
|
|
268
268
|
experiment_name=run.experiment_name,
|
|
269
269
|
run_name=run.name,
|
|
270
|
-
|
|
270
|
+
metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
|
|
271
271
|
)
|
|
272
272
|
|
|
273
273
|
def log_param(
|
|
@@ -296,13 +296,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
296
296
|
to string.
|
|
297
297
|
"""
|
|
298
298
|
run = self._get_or_start_run()
|
|
299
|
-
|
|
299
|
+
params_list = []
|
|
300
300
|
for key, value in params.items():
|
|
301
|
-
|
|
302
|
-
self._sql_client.
|
|
301
|
+
params_list.append(entities.Param(key, str(value)))
|
|
302
|
+
self._sql_client.modify_run_add_params(
|
|
303
303
|
experiment_name=run.experiment_name,
|
|
304
304
|
run_name=run.name,
|
|
305
|
-
|
|
305
|
+
params=json.dumps([param.to_dict() for param in params_list]),
|
|
306
306
|
)
|
|
307
307
|
|
|
308
308
|
def log_artifact(
|
|
@@ -25,7 +25,7 @@ RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
|
|
|
25
25
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
|
26
26
|
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
|
|
27
27
|
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
|
|
28
|
-
DEFAULT_IMAGE_TAG = "1.
|
|
28
|
+
DEFAULT_IMAGE_TAG = "1.8.0"
|
|
29
29
|
DEFAULT_ENTRYPOINT_PATH = "func.py"
|
|
30
30
|
|
|
31
31
|
# Percent of container memory to allocate for /dev/shm volume
|
|
@@ -234,12 +234,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
234
234
|
if payload_dir and payload_dir not in sys.path:
|
|
235
235
|
sys.path.insert(0, payload_dir)
|
|
236
236
|
|
|
237
|
-
# Create a Snowpark session before running the script
|
|
238
|
-
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
239
|
-
config = SnowflakeLoginOptions()
|
|
240
|
-
config["client_session_keep_alive"] = "True"
|
|
241
|
-
session = Session.builder.configs(config).create() # noqa: F841
|
|
242
|
-
|
|
243
237
|
try:
|
|
244
238
|
|
|
245
239
|
if main_func:
|
|
@@ -266,7 +260,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
266
260
|
finally:
|
|
267
261
|
# Restore original sys.argv
|
|
268
262
|
sys.argv = original_argv
|
|
269
|
-
session.close()
|
|
270
263
|
|
|
271
264
|
|
|
272
265
|
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
|
|
@@ -297,6 +290,12 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
297
290
|
except ModuleNotFoundError:
|
|
298
291
|
warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)
|
|
299
292
|
|
|
293
|
+
# Create a Snowpark session before starting
|
|
294
|
+
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
295
|
+
config = SnowflakeLoginOptions()
|
|
296
|
+
config["client_session_keep_alive"] = "True"
|
|
297
|
+
session = Session.builder.configs(config).create() # noqa: F841
|
|
298
|
+
|
|
300
299
|
try:
|
|
301
300
|
# Wait for minimum required instances if specified
|
|
302
301
|
min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
|
|
@@ -352,6 +351,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
352
351
|
f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
|
|
353
352
|
)
|
|
354
353
|
|
|
354
|
+
# Close the session after serializing the result
|
|
355
|
+
session.close()
|
|
356
|
+
|
|
355
357
|
|
|
356
358
|
if __name__ == "__main__":
|
|
357
359
|
# Parse command line arguments
|
snowflake/ml/jobs/job.py
CHANGED
|
@@ -83,6 +83,8 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
83
83
|
def _container_spec(self) -> dict[str, Any]:
|
|
84
84
|
"""Get the job's main container spec."""
|
|
85
85
|
containers = self._service_spec["spec"]["containers"]
|
|
86
|
+
if len(containers) == 1:
|
|
87
|
+
return cast(dict[str, Any], containers[0])
|
|
86
88
|
try:
|
|
87
89
|
container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
|
|
88
90
|
except StopIteration:
|
|
@@ -163,7 +165,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
163
165
|
Returns:
|
|
164
166
|
The job's execution logs.
|
|
165
167
|
"""
|
|
166
|
-
logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
|
|
168
|
+
logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
|
|
167
169
|
assert isinstance(logs, str) # mypy
|
|
168
170
|
if as_list:
|
|
169
171
|
return logs.splitlines()
|
|
@@ -281,7 +283,12 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
|
|
|
281
283
|
|
|
282
284
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
|
|
283
285
|
def _get_logs(
|
|
284
|
-
session: snowpark.Session,
|
|
286
|
+
session: snowpark.Session,
|
|
287
|
+
job_id: str,
|
|
288
|
+
limit: int = -1,
|
|
289
|
+
instance_id: Optional[int] = None,
|
|
290
|
+
container_name: str = constants.DEFAULT_CONTAINER_NAME,
|
|
291
|
+
verbose: bool = True,
|
|
285
292
|
) -> str:
|
|
286
293
|
"""
|
|
287
294
|
Retrieve the job's execution logs.
|
|
@@ -291,6 +298,7 @@ def _get_logs(
|
|
|
291
298
|
limit: The maximum number of lines to return. Negative values are treated as no limit.
|
|
292
299
|
session: The Snowpark session to use. If none specified, uses active session.
|
|
293
300
|
instance_id: Optional instance ID to get logs from a specific instance.
|
|
301
|
+
container_name: The container name to get logs from a specific container.
|
|
294
302
|
verbose: Whether to return the full log or just the portion between START and END messages.
|
|
295
303
|
|
|
296
304
|
Returns:
|
|
@@ -311,7 +319,7 @@ def _get_logs(
|
|
|
311
319
|
params: list[Any] = [
|
|
312
320
|
job_id,
|
|
313
321
|
0 if instance_id is None else instance_id,
|
|
314
|
-
|
|
322
|
+
container_name,
|
|
315
323
|
]
|
|
316
324
|
if limit > 0:
|
|
317
325
|
params.append(limit)
|
|
@@ -337,7 +345,7 @@ def _get_logs(
|
|
|
337
345
|
job_id,
|
|
338
346
|
limit=limit,
|
|
339
347
|
instance_id=instance_id if instance_id else 0,
|
|
340
|
-
container_name=
|
|
348
|
+
container_name=container_name,
|
|
341
349
|
)
|
|
342
350
|
full_log = os.linesep.join(row[0] for row in logs)
|
|
343
351
|
|
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -232,6 +232,7 @@ def submit_file(
|
|
|
232
232
|
enable_metrics (bool): Whether to enable metrics publishing for the job.
|
|
233
233
|
query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
|
|
234
234
|
spec_overrides (dict): A dictionary of overrides for the service spec.
|
|
235
|
+
imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
|
|
235
236
|
|
|
236
237
|
Returns:
|
|
237
238
|
An object representing the submitted job.
|
|
@@ -286,6 +287,7 @@ def submit_directory(
|
|
|
286
287
|
enable_metrics (bool): Whether to enable metrics publishing for the job.
|
|
287
288
|
query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
|
|
288
289
|
spec_overrides (dict): A dictionary of overrides for the service spec.
|
|
290
|
+
imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
|
|
289
291
|
|
|
290
292
|
Returns:
|
|
291
293
|
An object representing the submitted job.
|
|
@@ -341,6 +343,7 @@ def submit_from_stage(
|
|
|
341
343
|
enable_metrics (bool): Whether to enable metrics publishing for the job.
|
|
342
344
|
query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
|
|
343
345
|
spec_overrides (dict): A dictionary of overrides for the service spec.
|
|
346
|
+
imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
|
|
344
347
|
|
|
345
348
|
Returns:
|
|
346
349
|
An object representing the submitted job.
|
|
@@ -404,6 +407,8 @@ def _submit_job(
|
|
|
404
407
|
"num_instances", # deprecated
|
|
405
408
|
"target_instances",
|
|
406
409
|
"min_instances",
|
|
410
|
+
"enable_metrics",
|
|
411
|
+
"query_warehouse",
|
|
407
412
|
],
|
|
408
413
|
)
|
|
409
414
|
def _submit_job(
|
|
@@ -447,6 +452,13 @@ def _submit_job(
|
|
|
447
452
|
)
|
|
448
453
|
target_instances = max(target_instances, kwargs.pop("num_instances"))
|
|
449
454
|
|
|
455
|
+
imports = None
|
|
456
|
+
if "additional_payloads" in kwargs:
|
|
457
|
+
logger.warning(
|
|
458
|
+
"'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
|
|
459
|
+
)
|
|
460
|
+
imports = kwargs.pop("additional_payloads")
|
|
461
|
+
|
|
450
462
|
# Use kwargs for less common optional parameters
|
|
451
463
|
database = kwargs.pop("database", None)
|
|
452
464
|
schema = kwargs.pop("schema", None)
|
|
@@ -457,10 +469,7 @@ def _submit_job(
|
|
|
457
469
|
spec_overrides = kwargs.pop("spec_overrides", None)
|
|
458
470
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
459
471
|
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if additional_payloads:
|
|
463
|
-
logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")
|
|
472
|
+
imports = kwargs.pop("imports", None) or imports
|
|
464
473
|
|
|
465
474
|
# Warn if there are unknown kwargs
|
|
466
475
|
if kwargs:
|
|
@@ -492,7 +501,7 @@ def _submit_job(
|
|
|
492
501
|
try:
|
|
493
502
|
# Upload payload
|
|
494
503
|
uploaded_payload = payload_utils.JobPayload(
|
|
495
|
-
source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=
|
|
504
|
+
source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
|
|
496
505
|
).upload(session, stage_path)
|
|
497
506
|
except snowpark.exceptions.SnowparkSQLException as e:
|
|
498
507
|
if e.sql_error_code == 90106:
|
|
@@ -501,6 +510,22 @@ def _submit_job(
|
|
|
501
510
|
)
|
|
502
511
|
raise
|
|
503
512
|
|
|
513
|
+
# FIXME: Temporary patches, remove this after v1 is deprecated
|
|
514
|
+
if target_instances > 1:
|
|
515
|
+
default_spec_overrides = {
|
|
516
|
+
"spec": {
|
|
517
|
+
"endpoints": [
|
|
518
|
+
{"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
|
|
519
|
+
]
|
|
520
|
+
},
|
|
521
|
+
}
|
|
522
|
+
if spec_overrides:
|
|
523
|
+
spec_overrides = spec_utils.merge_patch(
|
|
524
|
+
default_spec_overrides, spec_overrides, display_name="spec_overrides"
|
|
525
|
+
)
|
|
526
|
+
else:
|
|
527
|
+
spec_overrides = default_spec_overrides
|
|
528
|
+
|
|
504
529
|
if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
|
|
505
530
|
# Add default env vars (extracted from spec_utils.generate_service_spec)
|
|
506
531
|
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
@@ -668,8 +693,10 @@ def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
|
|
|
668
693
|
session = session or get_active_session()
|
|
669
694
|
except snowpark.exceptions.SnowparkSessionException as e:
|
|
670
695
|
if "More than one active session" in e.message:
|
|
671
|
-
raise RuntimeError(
|
|
696
|
+
raise RuntimeError(
|
|
697
|
+
"More than one active session is found. Please specify the session explicitly as a parameter"
|
|
698
|
+
) from None
|
|
672
699
|
if "No default Session is found" in e.message:
|
|
673
|
-
raise RuntimeError("Please create a session
|
|
700
|
+
raise RuntimeError("No active session is found. Please create a session") from None
|
|
674
701
|
raise
|
|
675
702
|
return session
|
|
@@ -83,7 +83,6 @@ class LineageNode(mixins.SerializableSessionMixin):
|
|
|
83
83
|
raise NotImplementedError()
|
|
84
84
|
|
|
85
85
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
|
86
|
-
@snowpark._internal.utils.private_preview(version="1.5.3")
|
|
87
86
|
def lineage(
|
|
88
87
|
self,
|
|
89
88
|
direction: Literal["upstream", "downstream"] = "downstream",
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
|
1
|
-
from snowflake.ml.model._client.model.batch_inference_specs import
|
|
2
|
-
InputSpec,
|
|
3
|
-
JobSpec,
|
|
4
|
-
OutputSpec,
|
|
5
|
-
)
|
|
1
|
+
from snowflake.ml.model._client.model.batch_inference_specs import JobSpec, OutputSpec
|
|
6
2
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
7
3
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
8
4
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
9
5
|
|
|
10
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "
|
|
6
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "JobSpec", "OutputSpec"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
from snowflake.ml.model._client.ops import service_ops
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _get_inference_engine_args(
    experimental_options: Optional[dict[str, Any]],
) -> Optional[service_ops.InferenceEngineArgs]:
    """Build inference engine args from a user-supplied ``experimental_options`` mapping.

    Args:
        experimental_options: Optional mapping. When non-empty it must contain the key
            ``"inference_engine"`` and may contain ``"inference_engine_args_override"``.

    Returns:
        ``None`` when ``experimental_options`` is ``None`` or empty; otherwise a
        ``service_ops.InferenceEngineArgs`` built from the two keys above.

    Raises:
        ValueError: If a non-empty ``experimental_options`` lacks ``"inference_engine"``.
    """
    # None and an empty mapping both mean "no inference engine requested".
    if not experimental_options:
        return None

    has_engine = "inference_engine" in experimental_options
    if not has_engine:
        raise ValueError("inference_engine is required in experimental_options")

    engine = experimental_options["inference_engine"]
    overrides = experimental_options.get("inference_engine_args_override")
    return service_ops.InferenceEngineArgs(
        inference_engine=engine,
        inference_engine_args_override=overrides,
    )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _enrich_inference_engine_args(
    inference_engine_args: service_ops.InferenceEngineArgs,
    gpu_requests: Optional[Union[str, int]] = None,
) -> Optional[service_ops.InferenceEngineArgs]:
    """Enrich inference engine args with tensor-parallelism settings derived from ``gpu_requests``.

    ``inference_engine_args`` is modified in place and also returned. Its
    ``inference_engine_args_override`` list is created if it is ``None``. When ``gpu_requests``
    is given, ``--tensor-parallel-size=<n>`` is appended to that list.

    Args:
        inference_engine_args: The original inference engine args.
        gpu_requests: The number of GPUs requested, as an int or an int-parsable string.
            When ``None``, no tensor-parallelism flag is added.

    Returns:
        The same ``inference_engine_args`` object, enriched.

    Raises:
        ValueError: If ``gpu_requests`` cannot be converted to an integer, or is not greater than 0.
    """
    if inference_engine_args.inference_engine_args_override is None:
        inference_engine_args.inference_engine_args_override = []

    if gpu_requests is not None:
        # Keep only the conversion inside the try. Previously the "must be greater than 0" error
        # was raised inside it and immediately swallowed by the except clause, losing its message.
        try:
            gpu_count = int(gpu_requests)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}"
            ) from e
        if gpu_count <= 0:
            raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
        inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")

    return inference_engine_args
|