snowflake-ml-python 1.8.5__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/payload_utils.py +83 -35
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +23 -1
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +6 -7
- snowflake/ml/jobs/job.py +24 -9
- snowflake/ml/jobs/manager.py +102 -19
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +19 -4
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +14 -5
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +33 -30
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/stage_utils.py
ADDED
@@ -0,0 +1,119 @@
+import os
+import re
+from os import PathLike
+from pathlib import Path, PurePath
+from typing import Union
+
+from snowflake.ml._internal.utils import identifier
+
+PROTOCOL_NAME = "snow"
+_SNOWURL_PATH_RE = re.compile(
+    rf"^(?:(?:{PROTOCOL_NAME}://)?"
+    r"(?<!@)(?P<domain>\w+)/"
+    rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+    r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+)
+
+_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+class StagePath:
+    def __init__(self, path: str) -> None:
+        stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+        if not stage_match:
+            raise ValueError(f"{path} is not a valid stage path")
+        path = path.strip()
+        self._raw_path = path
+        relpath = stage_match.group("relpath")
+        start, _ = stage_match.span("relpath")
+        self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+        self._path = Path(relpath or "")
+
+    @property
+    def parent(self) -> "StagePath":
+        if self._path.parent == Path(""):
+            return StagePath(self._root)
+        else:
+            return StagePath(f"{self._root}/{self._path.parent}")
+
+    @property
+    def root(self) -> str:
+        return self._root
+
+    @property
+    def suffix(self) -> str:
+        return self._path.suffix
+
+    def _compose_path(self, path: Path) -> str:
+        # in pathlib, Path("") = "."
+        if path == Path(""):
+            return self.root
+        else:
+            return f"{self.root}/{path}"
+
+    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if stage_path.root == self.root:
+            return self._path.is_relative_to(stage_path._path)
+        else:
+            return False
+
+    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            return self._path.relative_to(stage_path._path)
+        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+    def absolute(self) -> "StagePath":
+        return self
+
+    def as_posix(self) -> str:
+        return self._compose_path(self._path)
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def exists(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_file(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_dir(self) -> bool:
+        return True
+
+    def is_absolute(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return self.as_posix()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StagePath):
+            raise NotImplementedError
+        return bool(self.root == other.root and self._path == other._path)
+
+    def __fspath__(self) -> str:
+        return self._compose_path(self._path)
+
+    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        path = self
+        for arg in args:
+            path = path._make_child(arg)
+        return path
+
+    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            child_path = self._path.joinpath(stage_path._path)
+            return StagePath(self._compose_path(child_path))
+        else:
+            return stage_path
+
+
+def identify_stage_path(path: str) -> Union[StagePath, Path]:
+    try:
+        stage_path = StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -2,18 +2,22 @@ from dataclasses import dataclass
 from pathlib import PurePath
 from typing import Literal, Optional, Union
 
+from snowflake.ml.jobs._utils import stage_utils
+
 JOB_STATUS = Literal[
     "PENDING",
     "RUNNING",
     "FAILED",
     "DONE",
+    "CANCELLING",
+    "CANCELLED",
     "INTERNAL_ERROR",
 ]
 
 
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: PurePath
+    file_path: Union[PurePath, stage_utils.StagePath]
     main_func: Optional[str]
 
 
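
Two small consequences of this hunk, sketched under the assumption that file_path and main_func are the dataclass's only fields (the hunk shows no others):

from pathlib import PurePath
from snowflake.ml.jobs._utils import stage_utils, types

# file_path now accepts a stage location as well as a local path
local = types.PayloadEntrypoint(file_path=PurePath("src/main.py"), main_func="main")
staged = types.PayloadEntrypoint(file_path=stage_utils.StagePath("@payload_stage/src/main.py"), main_func=None)

status: types.JOB_STATUS = "CANCELLED"  # newly valid status literal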
snowflake/ml/jobs/decorators.py
CHANGED
@@ -7,7 +7,7 @@ from typing_extensions import ParamSpec
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import
+from snowflake.ml.jobs._utils import payload_utils
 
 _PROJECT = "MLJob"
 
@@ -25,7 +25,7 @@ def remote(
     query_warehouse: Optional[str] = None,
     env_vars: Optional[dict[str, str]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -42,8 +42,8 @@ def remote(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            If set, the job will not start until the minimum number of nodes is available.
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use for the job.
         schema: The schema to use for the job.
@@ -62,8 +62,7 @@ def remote(
 
         @functools.wraps(func)
         def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload =
-            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+            payload = payload_utils.create_function_payload(func, *args, **kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
@@ -77,7 +76,7 @@ def remote(
                 enable_metrics=enable_metrics,
                 database=database,
                 schema=schema,
-                session=session,
+                session=payload.session or session,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
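
Net effect of these hunks: the decorator builds its payload via payload_utils.create_function_payload, and min_instances becomes optional, defaulting to target_instances. A hypothetical usage sketch (pool and stage names are placeholders):

from snowflake.ml.jobs import remote

@remote("MY_COMPUTE_POOL", stage_name="payload_stage", target_instances=3)  # min_instances defaults to 3
def train(x: int) -> int:
    return x * 2

job = train(21)        # submits an MLJob; a session captured in the payload takes precedence
result = job.result()  # blocks until the job reaches a terminal status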
snowflake/ml/jobs/job.py
CHANGED
@@ -14,7 +14,7 @@ from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
-TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
 
 T = TypeVar("T")
 
@@ -183,14 +183,14 @@ class MLJob(Generic[T]):
             pool_info = _get_compute_pool_info(self._session, self._compute_pool)
             if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
                 logger.warning(
-                    f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use
-                    " Job execution may be delayed."
+                    f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
+                    f"{self.min_instances} nodes required). Job execution may be delayed."
                 )
                 warning_shown = True
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                 raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
-            delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
+            delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
     @snowpark._internal.utils.private_preview(version="1.8.2")
@@ -338,16 +338,23 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
 
     Raises:
         RuntimeError: If the instances died or if some instances disappeared.
-
     """
+
+    target_instances = _get_target_instances(session, job_id)
+
+    if target_instances == 1:
+        return 0
+
     try:
         rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
     except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn't retrieve instances")
+
     if not rows:
         return None
-
+
+    if target_instances > len(rows):
         raise RuntimeError("Couldn't retrieve head instance due to missing instances.")
 
     # Sort by start_time first, then by instance_id
@@ -414,12 +421,20 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 
     Returns:
         Row: The compute pool information.
+
+    Raises:
+        ValueError: If the compute pool is not found.
     """
-
-
+    try:
+        (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+        return pool_info
+    except ValueError as e:
+        if "not enough values to unpack" in str(e):
+            raise ValueError(f"Compute pool '{compute_pool}' not found")
+        raise
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-    return int(row["target_instances"])
+    return int(row["target_instances"])
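
Two behavioral notes on the hunks above. First, _get_compute_pool_info detects a missing pool through tuple unpacking: SHOW COMPUTE POOLS LIKE returning zero rows makes (pool_info,) = ... raise the "not enough values to unpack" ValueError it inspects. Second, the poll backoff factor drops from 2 to 1.2, which ramps delays far more gently; a standalone illustration (initial delay and cap are assumed values here, the real cap being constants.JOB_POLL_MAX_DELAY_SECONDS):

def delays(factor: float, initial: float = 1.0, cap: float = 30.0, n: int = 8) -> list[float]:
    out, d = [], initial
    for _ in range(n):
        out.append(d)
        d = min(d * factor, cap)
    return out

delays(2.0)  # [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0, 30.0] -- hits the cap by the sixth poll
delays(1.2)  # [1.0, 1.2, 1.44, 1.73, 2.07, 2.49, 2.99, 3.58] (rounded) -- a much smoother ramp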
snowflake/ml/jobs/manager.py
CHANGED
@@ -87,13 +87,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
     """Delete a job service from the backend. Status and logs will be lost."""
-    if isinstance(job, jb.MLJob)
-
-
-
-
-
-
+    job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
+    session = job._session
+    try:
+        stage_path = job._stage_path
+        session.sql(f"REMOVE {stage_path}/").collect()
+        logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
+    except Exception as e:
+        logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+    session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -109,7 +111,7 @@ def submit_file(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -129,7 +131,8 @@ def submit_file(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -171,7 +174,7 @@ def submit_directory(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -192,7 +195,8 @@ def submit_directory(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -221,6 +225,72 @@ def submit_directory(
     )
 
 
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_from_stage(
+    source: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
+    target_instances: int = 1,
+    min_instances: Optional[int] = None,
+    enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob[None]:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        source: a stage path or a stage containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: a stage path containing the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in container
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
+        enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use.
+        schema: The schema to use.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=source,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        target_instances=target_instances,
+        min_instances=min_instances,
+        enable_metrics=enable_metrics,
+        database=database,
+        schema=schema,
+        session=session,
+    )
+
+
 @overload
 def _submit_job(
     source: str,
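
A hypothetical invocation of the new entry point (names are placeholders; the payload is assumed to already exist on the stage, and the two-line jobs/__init__.py change presumably exports the function):

from snowflake.ml import jobs

job = jobs.submit_from_stage(
    "@payload_stage/my_project",  # source: stage path holding the payload
    "MY_COMPUTE_POOL",
    entrypoint="@payload_stage/my_project/main.py",
    stage_name="payload_stage",
    args=["--epochs", "10"],
    target_instances=2,  # min_instances defaults to 2
)
print(job.status)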
@@ -235,7 +305,7 @@ def _submit_job(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -258,7 +328,7 @@ def _submit_job(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -292,7 +362,7 @@ def _submit_job(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int =
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -313,7 +383,8 @@ def _submit_job(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use.
         schema: The schema to use.
@@ -328,13 +399,25 @@ def _submit_job(
     """
     if database and not schema:
         raise ValueError("Schema must be specified if database is specified.")
-    if target_instances < 1
-    raise ValueError("target_instances
-
-
+    if target_instances < 1:
+        raise ValueError("target_instances must be greater than 0.")
+
+    min_instances = target_instances if min_instances is None else min_instances
+    if not (0 < min_instances <= target_instances):
+        raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
 
     session = session or get_active_session()
 
+    if min_instances > 1:
+        # Validate min_instances against compute pool max_nodes
+        pool_info = jb._get_compute_pool_info(session, compute_pool)
+        max_nodes = int(pool_info["max_nodes"])
+        if min_instances > max_nodes:
+            raise ValueError(
+                f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+            )
+
     # Validate database and schema identifiers on client side since
     # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
     database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
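
Pulled out of the hunk above, the effective min_instances resolution reads as follows (a paraphrase for clarity, not package code):

from typing import Optional

def resolve_min_instances(target_instances: int, min_instances: Optional[int]) -> int:
    if target_instances < 1:
        raise ValueError("target_instances must be greater than 0.")
    resolved = target_instances if min_instances is None else min_instances
    if not (0 < resolved <= target_instances):
        raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
    return resolved  # values > 1 are additionally checked against the pool's max_nodes

assert resolve_min_instances(3, None) == 3  # default: min == target
assert resolve_min_instances(3, 2) == 2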
snowflake/ml/model/_client/model/model_impl.py
CHANGED
@@ -426,3 +426,61 @@ class Model:
             schema_name=new_schema or self._model_ops._model_client._schema_name,
         )
         self._model_name = new_model
+
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model.
+
+        Returns:
+            str: HTML string containing formatted model details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Get default version
+        default_version = self.default.version_name
+
+        # Get versions info
+        try:
+            versions_df = self.show_versions()
+            versions_html = ""
+
+            for _, row in versions_df.iterrows():
+                versions_html += html_utils.create_version_item(
+                    version_name=row["name"],
+                    created_on=str(row["created_on"]),
+                    comment=str(row.get("comment", "")),
+                    is_default=bool(row["is_default_version"]),
+                )
+        except Exception:
+            versions_html = html_utils.create_error_message("Error retrieving versions")
+
+        # Get tags
+        try:
+            tags = self.show_tags()
+            if not tags:
+                tags_html = html_utils.create_error_message("No tags available")
+            else:
+                tags_html = ""
+                for tag_name, tag_value in tags.items():
+                    tags_html += html_utils.create_tag_item(tag_name, tag_value)
+        except Exception:
+            tags_html = html_utils.create_error_message("Error retrieving tags")
+
+        # Create main content sections
+        main_info = html_utils.create_grid_section(
+            [
+                ("Model Name", self.name),
+                ("Full Name", self.fully_qualified_name),
+                ("Description", self.description),
+                ("Default Version", default_version),
+            ]
+        )
+
+        versions_section = html_utils.create_section_header("Versions") + html_utils.create_content_section(
+            versions_html
+        )
+
+        tags_section = html_utils.create_section_header("Tags") + html_utils.create_content_section(tags_html)
+
+        content = main_info + versions_section + tags_section
+
+        return html_utils.create_base_container("Model Details", content)
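
_repr_html_ is the standard IPython/Jupyter rich-display hook, so a Model displayed as the last expression of a notebook cell now renders this card automatically; a sketch (registry setup and model name are placeholders):

from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session
model = reg.get_model("MY_MODEL")
model  # in a notebook this renders via the new _repr_html_; plain REPLs still use repr()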
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model version.
+
+        Returns:
+            str: HTML string containing formatted model version details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Get task
+        try:
+            task = self.get_model_task().value
+        except Exception:
+            task = (
+                html_utils.create_error_message("Not available")
+                .replace('<em style="color: #888; font-style: italic;">', "")
+                .replace("</em>", "")
+            )
+
+        # Get functions info for display
+        try:
+            functions = self.show_functions()
+            if not functions:
+                functions_html = html_utils.create_error_message("No functions available")
+            else:
+                functions_list = []
+                for func in functions:
+                    try:
+                        sig_html = func["signature"]._repr_html_()
+                    except Exception:
+                        # Fallback to simple display if can't display signature
+                        sig_html = f"<pre style='margin: 5px 0;'>{func['signature']}</pre>"
+
+                    function_content = f"""
+                        <div style="margin: 5px 0;">
+                            <strong>Target Method:</strong> {func['target_method']}
+                        </div>
+                        <div style="margin: 5px 0;">
+                            <strong>Function Type:</strong> {func.get('target_method_function_type', 'N/A')}
+                        </div>
+                        <div style="margin: 5px 0;">
+                            <strong>Partitioned:</strong> {func.get('is_partitioned', False)}
+                        </div>
+                        <div style="margin: 10px 0;">
+                            <strong>Signature:</strong>
+                            {sig_html}
+                        </div>
+                    """
+
+                    functions_list.append(
+                        html_utils.create_collapsible_section(
+                            title=func["name"], content=function_content, open_by_default=False
+                        )
+                    )
+                functions_html = "".join(functions_list)
+        except Exception:
+            functions_html = html_utils.create_error_message("Error retrieving functions")
+
+        # Get metrics for display
+        try:
+            metrics = self.show_metrics()
+            if not metrics:
+                metrics_html = html_utils.create_error_message("No metrics available")
+            else:
+                metrics_html = ""
+                for metric_name, value in metrics.items():
+                    metrics_html += html_utils.create_metric_item(metric_name, value)
+        except Exception:
+            metrics_html = html_utils.create_error_message("Error retrieving metrics")
+
+        # Create main content sections
+        main_info = html_utils.create_grid_section(
+            [
+                ("Model Name", self.model_name),
+                ("Version", f'<strong style="color: #28a745;">{self.version_name}</strong>'),
+                ("Full Name", self.fully_qualified_model_name),
+                ("Description", self.description),
+                ("Task", task),
+            ]
+        )
+
+        functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
+            functions_html
+        )
+
+        metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)
+
+        content = main_info + functions_section + metrics_section
+
+        return html_utils.create_base_container("Model Version Details", content)
+
     @classmethod
     def _ref(
         cls,
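
The func["signature"]._repr_html_() call above presumably lands on the +24-line change to snowflake/ml/model/_signatures/core.py listed at the top of this diff, giving model signatures their own HTML rendering; the <pre> fallback covers signatures that cannot render themselves.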
snowflake/ml/model/_client/ops/model_ops.py
CHANGED
@@ -643,14 +643,17 @@ class ModelOperator:
         # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
-        result = []
-
+        result: list[ServiceInfo] = []
         for fully_qualified_service_name in fully_qualified_service_names:
             ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
-
+            statuses = self._service_client.get_service_container_statuses(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             )
+            if len(statuses) == 0:
+                return result
+
+            service_status = statuses[0].service_status
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):