snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
```diff
@@ -12,12 +12,14 @@ from snowflake.ml.jobs._utils import constants, query_helper, types
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows =
-
+    rows = query_helper.run_query(
+        session,
+        "show compute pools like ?",
+        params=[compute_pool],
+    )
+    if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-
-    compute_pool_info = rows["data"]
-    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
+    instance_family: str = rows[0]["instance_family"]
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)

     return (
@@ -179,10 +181,10 @@ def generate_service_spec(
     # TODO: Add hooks for endpoints for integration with TensorBoard etc

     env_vars = {
-        constants.PAYLOAD_DIR_ENV_VAR:
+        constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
     }
-    endpoints = []
+    endpoints: list[dict[str, Any]] = []

     if target_instances > 1:
         # Update environment variables for multi-node job
@@ -191,7 +193,7 @@ def generate_service_spec(
         env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)

         # Define Ray endpoints for intra-service instance communication
-        ray_endpoints = [
+        ray_endpoints: list[dict[str, Any]] = [
             {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
             {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
             {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
@@ -234,6 +236,19 @@ def generate_service_spec(
         ],
         "volumes": volumes,
     }
+
+    if target_instances > 1:
+        spec_dict.update(
+            {
+                "resourceManagement": {
+                    "controlPolicy": {
+                        "startupOrder": {
+                            "type": "FirstInstance",
+                        },
+                    },
+                },
+            }
+        )
     if endpoints:
         spec_dict["endpoints"] = endpoints
     if metrics:
```
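The multi-node branch added at the end of `generate_service_spec` pins service startup order so the head instance comes up before the workers. A minimal sketch of the spec fragment this produces, assuming the surrounding `spec_dict` layout shown above (the `containers`/`volumes` values here are placeholders):

```python
# Sketch: the fragment the new target_instances > 1 branch merges into the
# service spec. Keys under "resourceManagement" are from the diff above;
# the other spec_dict entries are placeholder values.
import yaml

spec_dict = {
    "containers": [],  # elided
    "volumes": [],     # elided
}
spec_dict.update(
    {
        "resourceManagement": {
            "controlPolicy": {
                "startupOrder": {
                    "type": "FirstInstance",  # instance 0 starts before the rest
                },
            },
        },
    }
)
print(yaml.safe_dump({"spec": spec_dict}, sort_keys=False))
```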
snowflake/ml/jobs/_utils/stage_utils.py
CHANGED
```diff
@@ -14,7 +14,10 @@ _SNOWURL_PATH_RE = re.compile(
     r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
 )

-
+# Break long regex into two main parts
+_STAGE_PATTERN = rf"~|%?(?:(?:{identifier._SF_IDENTIFIER}\.?){{,2}}{identifier._SF_IDENTIFIER})"
+_RELPATH_PATTERN = r"[\w\-./]*"
+_STAGEF_PATH_RE = re.compile(rf"^@(?P<stage>{_STAGE_PATTERN})(?:/(?P<relpath>{_RELPATH_PATTERN}))?$")


 class StagePath:
@@ -29,6 +32,14 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")

+    @property
+    def parts(self) -> tuple[str, ...]:
+        return self._path.parts
+
+    @property
+    def name(self) -> str:
+        return self._path.name
+
     @property
     def parent(self) -> "StagePath":
         if self._path.parent == Path(""):
@@ -51,18 +62,28 @@ class StagePath:
         else:
             return f"{self.root}/{path}"

-    def is_relative_to(self,
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        if not other:
+            raise TypeError("is_relative_to() requires at least one argument")
+        # For now, we only support a single argument, like pathlib.Path in Python < 3.12
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if stage_path.root == self.root:
             return self._path.is_relative_to(stage_path._path)
         else:
             return False

-    def relative_to(self,
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        if not other:
+            raise TypeError("relative_to() requires at least one argument")
+        if not self.is_relative_to(*other):
+            raise ValueError(f"{other} does not start with {self._raw_path}")
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             return self._path.relative_to(stage_path._path)
-
+        else:
+            raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")

     def absolute(self) -> "StagePath":
         return self
@@ -88,6 +109,9 @@ class StagePath:
     def __str__(self) -> str:
         return self.as_posix()

+    def __repr__(self) -> str:
+        return f"StagePath('{self.as_posix()}')"
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, StagePath):
             raise NotImplementedError
@@ -96,24 +120,16 @@ class StagePath:
     def __fspath__(self) -> str:
         return self._compose_path(self._path)

-    def joinpath(self, *args: Union[str, PathLike[str]
+    def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath":
         path = self
         for arg in args:
             path = path._make_child(arg)
         return path

-    def _make_child(self, path: Union[str, PathLike[str]
+    def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath":
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             child_path = self._path.joinpath(stage_path._path)
             return StagePath(self._compose_path(child_path))
         else:
             return stage_path
-
-
-def identify_stage_path(path: str) -> Union[StagePath, Path]:
-    try:
-        stage_path = StagePath(path)
-    except ValueError:
-        return Path(path)
-    return stage_path
```
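`StagePath` gains `parts`, `name`, a `__repr__`, and pathlib-compatible `*other` signatures for `is_relative_to`/`relative_to` (the latter now raises on a root mismatch instead of falling through). A usage sketch inferred from the diff; the stage name is hypothetical:

```python
# Behavior inferred from the StagePath changes above; "@my_stage" is made up.
from snowflake.ml.jobs._utils.stage_utils import StagePath

p = StagePath("@my_stage/jobs/run_1/output.json")
print(p.name)    # "output.json"
print(p.parts)   # ("jobs", "run_1", "output.json") -- parts of the relative path
print(repr(p))   # StagePath('@my_stage/jobs/run_1/output.json')

base = StagePath("@my_stage/jobs")
if p.is_relative_to(base):      # roots match, so pathlib semantics apply
    print(p.relative_to(base))  # run_1/output.json (a plain PurePath)
```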
snowflake/ml/jobs/_utils/types.py
CHANGED
```diff
@@ -1,8 +1,7 @@
+import os
 from dataclasses import dataclass
 from pathlib import PurePath
-from typing import Literal, Optional, Union
-
-from snowflake.ml.jobs._utils import stage_utils
+from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable

 JOB_STATUS = Literal[
     "PENDING",
@@ -15,9 +14,70 @@ JOB_STATUS = Literal[
 ]


+@runtime_checkable
+class PayloadPath(Protocol):
+    """A protocol for path-like objects used in this module, covering methods from pathlib.Path and StagePath."""
+
+    @property
+    def name(self) -> str:
+        ...
+
+    @property
+    def suffix(self) -> str:
+        ...
+
+    @property
+    def parent(self) -> "PayloadPath":
+        ...
+
+    def exists(self) -> bool:
+        ...
+
+    def is_file(self) -> bool:
+        ...
+
+    def is_absolute(self) -> bool:
+        ...
+
+    def absolute(self) -> "PayloadPath":
+        ...
+
+    def joinpath(self, *other: Union[str, os.PathLike[str]]) -> "PayloadPath":
+        ...
+
+    def as_posix(self) -> str:
+        ...
+
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        ...
+
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        ...
+
+    def __fspath__(self) -> str:
+        ...
+
+    def __str__(self) -> str:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+@dataclass
+class PayloadSpec:
+    """Represents a payload item to be uploaded."""
+
+    source_path: PayloadPath
+    remote_relative_path: Optional[PurePath] = None
+
+    def __iter__(self) -> Iterator[Union[PayloadPath, Optional[PurePath]]]:
+        return iter((self.source_path, self.remote_relative_path))
+
+
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path:
+    file_path: PayloadPath
     main_func: Optional[str]

```
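`PayloadPath` abstracts over `pathlib.Path` and `StagePath` structurally rather than by inheritance, and `runtime_checkable` makes `isinstance` checks work against it. A small sketch (import path as in this wheel; the file names are made up):

```python
# Sketch: pathlib.Path satisfies the PayloadPath protocol structurally.
from pathlib import Path, PurePath

from snowflake.ml.jobs._utils.types import PayloadPath, PayloadSpec

local: PayloadPath = Path("train.py")  # accepted without subclassing
print(isinstance(local, PayloadPath))  # True: runtime_checkable checks member presence

spec = PayloadSpec(local, PurePath("src/train.py"))
source, remote = spec                  # PayloadSpec.__iter__ enables tuple unpacking
```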
snowflake/ml/jobs/job.py
CHANGED
```diff
@@ -3,12 +3,12 @@ import logging
 import os
 import time
 from functools import cached_property
+from pathlib import Path
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload

 import yaml

 from snowflake import snowpark
-from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
@@ -70,8 +70,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
         row = _get_service_info(self._session, self.id)
-
-        return cast(str, compute_pool)
+        return cast(str, row["compute_pool"])

     @property
     def _service_spec(self) -> dict[str, Any]:
@@ -97,10 +96,24 @@ class MLJob(Generic[T], SerializableSessionMixin):
     @property
     def _result_path(self) -> str:
         """Get the job's result file location."""
-
-        if
+        result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path_str is None:
             raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
-
+        volume_mounts = self._container_spec["volumeMounts"]
+        stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
+
+        result_path = Path(result_path_str)
+        stage_mount = Path(stage_mount_str)
+        try:
+            relative_path = result_path.relative_to(stage_mount)
+        except ValueError:
+            if result_path.is_absolute():
+                raise ValueError(
+                    f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
+                )
+            relative_path = result_path
+
+        return f"{self._stage_path}/{relative_path.as_posix()}"

     @overload
     def get_logs(
@@ -183,20 +196,17 @@ class MLJob(Generic[T], SerializableSessionMixin):
         start_time = time.monotonic()
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
-
+            elapsed = time.monotonic() - start_time
+            if elapsed >= timeout >= 0:
+                raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds")
+            elif status == "PENDING" and not warning_shown and elapsed >= 2:  # Only show warning after 2s
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
-
-                if (
-                    pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
-                ) < self.min_instances:
+                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
                     logger.warning(
-                        f
-                        f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
+                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
                         f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
-            if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
```
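The reworked `wait` loop checks the timeout before the pending-pool warning and relies on the chained comparison `elapsed >= timeout >= 0`, which is only true when a non-negative timeout has elapsed, so a negative timeout means wait indefinitely. A standalone sketch of the same pattern (names here are illustrative, not the library's):

```python
import time

MAX_DELAY = 30.0  # stands in for constants.JOB_POLL_MAX_DELAY_SECONDS

def wait_until(is_done, timeout: float = -1.0, delay: float = 1.0) -> None:
    """Poll is_done() with capped exponential backoff; timeout < 0 waits forever."""
    start = time.monotonic()
    while not is_done():
        elapsed = time.monotonic() - start
        if elapsed >= timeout >= 0:  # false whenever timeout is negative
            raise TimeoutError(f"not done within {timeout} seconds")
        time.sleep(delay)
        delay = min(delay * 1.2, MAX_DELAY)  # grow delay 20% per poll, capped
```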
```diff
@@ -247,27 +257,21 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
     """Retrieve job or job instance execution status."""
     if instance_id is not None:
         # Get specific instance status
-        rows = session.
-
-
-
-        if isinstance(rows, dict) and "data" in rows:
-            for row in rows["data"]:
-                if row[request_attributes["instance_id"]] == str(instance_id):
-                    return cast(types.JOB_STATUS, row[request_attributes["status"]])
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        for row in rows:
+            if row["instance_id"] == str(instance_id):
+                return cast(types.JOB_STATUS, row["status"])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
         row = _get_service_info(session, job_id)
-
-        return cast(types.JOB_STATUS, row[request_attributes["status"]])
+        return cast(types.JOB_STATUS, row["status"])


 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
     row = _get_service_info(session, job_id)
-
-    return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))
+    return cast(dict[str, Any], yaml.safe_load(row["spec"]))


 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
@@ -307,18 +311,14 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-
+        (row,) = query_helper.run_query(
+            session,
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-            _force_qmark_paramstyle=True,
         )
-
-
-
-        else:
-            full_log = ""
-    except errors.ProgrammingError as e:
-        if "Container Status: PENDING" in str(e):
+        full_log = str(row[0])
+    except SnowparkSQLException as e:
+        if "Container Status: PENDING" in e.message:
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
@@ -399,7 +399,7 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in

     try:
         target_instances = _get_target_instances(session, job_id)
-    except
+    except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn't retrieve service information")

@@ -407,34 +407,32 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return 0

     try:
-        rows =
-
+        rows = query_helper.run_query(
+            session,
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)",
+            params=(job_id,),
         )
-    except
+    except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn't retrieve instances")

-    if not rows
+    if not rows:
         return None

-    if target_instances > len(rows
+    if target_instances > len(rows):
         raise RuntimeError("Couldn't retrieve head instance due to missing instances.")

-    requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
     # Sort by start_time first, then by instance_id
     try:
-        sorted_instances = sorted(
-            rows["data"],
-            key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
-        )
+        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
-    if not head_instance[
+    if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
         return None
     try:
-        return int(head_instance[
+        return int(head_instance["instance_id"])
     except (ValueError, TypeError):
         return 0

@@ -446,14 +444,16 @@ def _get_service_log_from_event_table(
     schema: Optional[str] = None,
     instance_id: Optional[int] = None,
     limit: int = -1,
-) ->
-
-        name,
-    ]
+) -> list[Row]:
+    event_table_name = session.sql("SHOW PARAMETERS LIKE 'event_table' IN ACCOUNT").collect()[0]["value"]
     query = [
-        "SELECT VALUE FROM
+        "SELECT VALUE FROM IDENTIFIER(?)",
         'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
     ]
+    params: list[Any] = [
+        event_table_name,
+        name,
+    ]
     if database:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
         params.append(database)
@@ -473,23 +473,22 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-
+    # the wrap used in query_helper does not have return type.
+    # sticking a # type: ignore[no-any-return] is to pass type check
+    rows = query_helper.run_query(
+        session,
+        "\n".join(line for line in query if line),
+        params=params,
     )
-
-        return []
-    return rows["data"]
+    return rows  # type: ignore[no-any-return]


 def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
-    row =
-
-    if not row or not isinstance(row, dict) or not row.get("data"):
-        raise errors.ProgrammingError("failed to retrieve service information")
-    return row["data"][0]
+    (row,) = query_helper.run_query(session, "DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,))
+    return row


-def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) ->
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
     """
     Check if the compute pool has enough available instances.

@@ -498,19 +497,16 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
         compute_pool (str): The name of the compute pool.

     Returns:
-
+        Row: The compute pool information.

     Raises:
         ValueError: If the compute pool is not found.
     """
     try:
-
-
-        )
-        #
-        if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
-            raise ValueError(f"Compute pool '{compute_pool}' not found")
-        return compute_pool_info["data"][0]
+        # the wrap used in query_helper does not have return type.
+        # sticking a # type: ignore[no-any-return] is to pass type check
+        (pool_info,) = query_helper.run_query(session, "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,))
+        return pool_info  # type: ignore[no-any-return]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
@@ -520,8 +516,7 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-
-    return int(row[requested_attributes["target_instances"]])
+    return int(row["target_instances"])


 def _get_logs_spcs(
```
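The job.py hunks consistently replace the positional `requested_attributes`/`get_attribute_map` lookups with name-based access on `snowpark.Row`, which supports both key and attribute styles. An illustrative sketch (the column values are made up):

```python
# snowpark Row values can be read by name, which is what the refactor uses.
from snowflake.snowpark import Row

row = Row(status="RUNNING", target_instances="2", max_nodes=3, active_nodes=1)
print(row["status"])                     # key access, as in _get_status()
print(int(row["target_instances"]))      # as in _get_target_instances()
print(row.max_nodes - row.active_nodes)  # attribute access, as in the wait loop
```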
snowflake/ml/jobs/manager.py
CHANGED
```diff
@@ -8,7 +8,6 @@ import pandas as pd
 import yaml

 from snowflake import snowpark
-from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
@@ -169,8 +168,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         job = jb.MLJob[Any](job_id, session=session)
         _ = job._service_spec
         return job
-    except
-        if "does not exist" in
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
             raise ValueError(f"Job does not exist: {job_id}") from e
         raise

@@ -186,7 +185,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
         logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
     except Exception as e:
         logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-
+    query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))


 @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -426,12 +425,18 @@ def _submit_job(
         An object representing the submitted job.

     Raises:
-        RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
-
+        SnowparkSQLException: If there is an error submitting the job.
     """
     session = session or get_active_session()

+    # Check for deprecated args
+    if "num_instances" in kwargs:
+        logger.warning(
+            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+        )
+        target_instances = max(target_instances, kwargs.pop("num_instances"))
+
     # Use kwargs for less common optional parameters
     database = kwargs.pop("database", None)
     schema = kwargs.pop("schema", None)
@@ -442,13 +447,10 @@ def _submit_job(
     spec_overrides = kwargs.pop("spec_overrides", None)
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", None)
+    additional_payloads = kwargs.pop("additional_payloads", None)

-
-
-        logger.warning(
-            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
-        )
-        target_instances = max(target_instances, kwargs.pop("num_instances"))
+    if additional_payloads:
+        logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")

     # Warn if there are unknown kwargs
     if kwargs:
```
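For callers, the relocated deprecation shim means `num_instances` still works (with a warning) and the larger of the two values wins via `max()`. A hypothetical migration sketch; the `submit_file` signature and argument values are assumptions, not taken from this diff:

```python
# Hypothetical call sites; names and values here are illustrative only.
from snowflake.ml import jobs

# Deprecated spelling: still accepted, but logs the warning shown above.
job = jobs.submit_file("train.py", "MY_POOL", stage_name="payload_stage", num_instances=4)

# Preferred spelling going forward.
job = jobs.submit_file("train.py", "MY_POOL", stage_name="payload_stage", target_instances=4)
```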
```diff
@@ -464,8 +466,7 @@
     if min_instances > 1:
         # Validate min_instances against compute pool max_nodes
         pool_info = jb._get_compute_pool_info(session, compute_pool)
-
-        max_nodes = int(pool_info[requested_attributes["max_nodes"]])
+        max_nodes = int(pool_info["max_nodes"])
        if min_instances > max_nodes:
             raise ValueError(
                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
@@ -480,9 +481,7 @@

     # Upload payload
     uploaded_payload = payload_utils.JobPayload(
-        source,
-        entrypoint=entrypoint,
-        pip_requirements=pip_requirements,
+        source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
     ).upload(session, stage_path)

     # Generate service spec
@@ -502,7 +501,48 @@
     if spec_overrides:
         spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")

-
+    query_text, params = _generate_submission_query(
+        spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+    )
+    try:
+        _ = query_helper.run_query(session, query_text, params=params)
+    except SnowparkSQLException as e:
+        if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
+            logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
+            spec["spec"].pop("resourceManagement", None)
+            query_text, params = _generate_submission_query(
+                spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+            )
+            _ = query_helper.run_query(session, query_text, params=params)
+        else:
+            raise
+    return get_job(job_id, session=session)
+
+
+def _generate_submission_query(
+    spec: dict[str, Any],
+    external_access_integrations: list[str],
+    query_warehouse: Optional[str],
+    target_instances: int,
+    session: snowpark.Session,
+    compute_pool: str,
+    job_id: str,
+) -> tuple[str, list[Any]]:
+    """
+    Generate the SQL query for job submission.
+
+    Args:
+        spec: The service spec for the job.
+        external_access_integrations: The external access integrations for the job.
+        query_warehouse: The query warehouse for the job.
+        target_instances: The number of instances for the job.
+        session: The Snowpark session to use.
+        compute_pool: The compute pool to use for the job.
+        job_id: The ID of the job.
+
+    Returns:
+        A tuple containing the SQL query text and the parameters for the query.
+    """
     query_template = textwrap.dedent(
         """\
         EXECUTE JOB SERVICE
@@ -526,17 +566,5 @@
     if target_instances > 1:
         query.append("REPLICAS = ?")
         params.append(target_instances)
-
-    # Submit job
     query_text = "\n".join(line for line in query if line)
-
-    try:
-        _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
-    except errors.ProgrammingError as e:
-        if "invalid property 'ASYNC'" in str(e):
-            raise RuntimeError(
-                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
-            ) from e
-        raise
-
-    return get_job(job_id, session=session)
+    return query_text, params
```