snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +42 -16
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +12 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +95 -39
- snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
- snowflake/ml/jobs/_utils/spec_utils.py +30 -6
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +10 -7
- snowflake/ml/jobs/job.py +176 -28
- snowflake/ml/jobs/manager.py +119 -26
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +24 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +73 -28
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +9 -3
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
1
3
|
import time
|
2
4
|
from functools import cached_property
|
3
5
|
from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
|
@@ -12,10 +14,12 @@ from snowflake.snowpark import Row, context as sp_context
|
|
12
14
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
13
15
|
|
14
16
|
_PROJECT = "MLJob"
|
15
|
-
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
17
|
+
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
|
16
18
|
|
17
19
|
T = TypeVar("T")
|
18
20
|
|
21
|
+
logger = logging.getLogger(__name__)
|
22
|
+
|
19
23
|
|
20
24
|
class MLJob(Generic[T]):
|
21
25
|
def __init__(
|
@@ -36,8 +40,15 @@ class MLJob(Generic[T]):
|
|
36
40
|
return identifier.parse_schema_level_object_identifier(self.id)[-1]
|
37
41
|
|
38
42
|
@cached_property
|
39
|
-
def
|
40
|
-
return
|
43
|
+
def target_instances(self) -> int:
|
44
|
+
return _get_target_instances(self._session, self.id)
|
45
|
+
|
46
|
+
@cached_property
def min_instances(self) -> int:
    """Minimum number of instances the job needs before it can start.

    The value is read from the ``constants.MIN_INSTANCES_ENV_VAR`` entry of
    the container spec's ``env`` mapping.

    Returns:
        The configured minimum, or 1 when the container spec carries no
        usable value.
    """
    try:
        return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
    except (TypeError, KeyError, AttributeError):
        # TypeError: spec/env value is not subscriptable or the value is None.
        # KeyError: the container spec has no "env" section at all.
        # AttributeError: "env" is present but null, so it has no .get().
        # In every case the job behaves like a single-instance job.
        return 1
|
41
52
|
|
42
53
|
@property
|
43
54
|
def id(self) -> str:
|
@@ -52,6 +63,12 @@ class MLJob(Generic[T]):
|
|
52
63
|
self._status = _get_status(self._session, self.id)
|
53
64
|
return self._status
|
54
65
|
|
66
|
+
@cached_property
def _compute_pool(self) -> str:
    """Return the name of the compute pool reported for this job's service."""
    service_info = _get_service_info(self._session, self.id)
    pool_name = service_info["compute_pool"]
    return cast(str, pool_name)
|
71
|
+
|
55
72
|
@property
|
56
73
|
def _service_spec(self) -> dict[str, Any]:
|
57
74
|
"""Get the job's service spec."""
|
@@ -82,15 +99,34 @@ class MLJob(Generic[T]):
|
|
82
99
|
return f"{self._stage_path}/{result_path}"
|
83
100
|
|
84
101
|
@overload
|
85
|
-
def get_logs(
|
102
|
+
def get_logs(
|
103
|
+
self,
|
104
|
+
limit: int = -1,
|
105
|
+
instance_id: Optional[int] = None,
|
106
|
+
*,
|
107
|
+
as_list: Literal[True],
|
108
|
+
verbose: bool = constants.DEFAULT_VERBOSE_LOG,
|
109
|
+
) -> list[str]:
|
86
110
|
...
|
87
111
|
|
88
112
|
@overload
|
89
|
-
def get_logs(
|
113
|
+
def get_logs(
|
114
|
+
self,
|
115
|
+
limit: int = -1,
|
116
|
+
instance_id: Optional[int] = None,
|
117
|
+
*,
|
118
|
+
as_list: Literal[False] = False,
|
119
|
+
verbose: bool = constants.DEFAULT_VERBOSE_LOG,
|
120
|
+
) -> str:
|
90
121
|
...
|
91
122
|
|
92
123
|
def get_logs(
|
93
|
-
self,
|
124
|
+
self,
|
125
|
+
limit: int = -1,
|
126
|
+
instance_id: Optional[int] = None,
|
127
|
+
*,
|
128
|
+
as_list: bool = False,
|
129
|
+
verbose: bool = constants.DEFAULT_VERBOSE_LOG,
|
94
130
|
) -> Union[str, list[str]]:
|
95
131
|
"""
|
96
132
|
Return the job's execution logs.
|
@@ -100,17 +136,20 @@ class MLJob(Generic[T]):
|
|
100
136
|
instance_id: Optional instance ID to get logs from a specific instance.
|
101
137
|
If not provided, returns logs from the head node.
|
102
138
|
as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
|
139
|
+
verbose: Whether to return the full log or just the user log.
|
103
140
|
|
104
141
|
Returns:
|
105
142
|
The job's execution logs.
|
106
143
|
"""
|
107
|
-
logs = _get_logs(self._session, self.id, limit, instance_id)
|
144
|
+
logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
|
108
145
|
assert isinstance(logs, str) # mypy
|
109
146
|
if as_list:
|
110
147
|
return logs.splitlines()
|
111
148
|
return logs
|
112
149
|
|
113
|
-
def show_logs(
|
150
|
+
def show_logs(
|
151
|
+
self, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = constants.DEFAULT_VERBOSE_LOG
|
152
|
+
) -> None:
|
114
153
|
"""
|
115
154
|
Display the job's execution logs.
|
116
155
|
|
@@ -118,8 +157,9 @@ class MLJob(Generic[T]):
|
|
118
157
|
limit: The maximum number of lines to display. Negative values are treated as no limit.
|
119
158
|
instance_id: Optional instance ID to get logs from a specific instance.
|
120
159
|
If not provided, displays logs from the head node.
|
160
|
+
verbose: Whether to return the full log or just the user log.
|
121
161
|
"""
|
122
|
-
print(self.get_logs(limit, instance_id, as_list=False)) # noqa: T201: we need to print here.
|
162
|
+
print(self.get_logs(limit, instance_id, as_list=False, verbose=verbose)) # noqa: T201: we need to print here.
|
123
163
|
|
124
164
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
|
125
165
|
def wait(self, timeout: float = -1) -> types.JOB_STATUS:
|
@@ -137,11 +177,20 @@ class MLJob(Generic[T]):
|
|
137
177
|
"""
|
138
178
|
delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS # Start with 100ms delay
|
139
179
|
start_time = time.monotonic()
|
140
|
-
|
180
|
+
warning_shown = False
|
181
|
+
while (status := self.status) not in TERMINAL_JOB_STATUSES:
|
182
|
+
if status == "PENDING" and not warning_shown:
|
183
|
+
pool_info = _get_compute_pool_info(self._session, self._compute_pool)
|
184
|
+
if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
|
185
|
+
logger.warning(
|
186
|
+
f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
|
187
|
+
f"{self.min_instances} nodes required). Job execution may be delayed."
|
188
|
+
)
|
189
|
+
warning_shown = True
|
141
190
|
if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
|
142
191
|
raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
|
143
192
|
time.sleep(delay)
|
144
|
-
delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
|
193
|
+
delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
|
145
194
|
return self.status
|
146
195
|
|
147
196
|
@snowpark._internal.utils.private_preview(version="1.8.2")
|
@@ -195,7 +244,9 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
|
|
195
244
|
|
196
245
|
|
197
246
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
|
198
|
-
def _get_logs(
|
247
|
+
def _get_logs(
|
248
|
+
session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
|
249
|
+
) -> str:
|
199
250
|
"""
|
200
251
|
Retrieve the job's execution logs.
|
201
252
|
|
@@ -204,24 +255,20 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
|
|
204
255
|
limit: The maximum number of lines to return. Negative values are treated as no limit.
|
205
256
|
session: The Snowpark session to use. If none specified, uses active session.
|
206
257
|
instance_id: Optional instance ID to get logs from a specific instance.
|
258
|
+
verbose: Whether to return the full log or just the portion between START and END messages.
|
207
259
|
|
208
260
|
Returns:
|
209
261
|
The job's execution logs.
|
210
262
|
|
211
263
|
Raises:
|
212
|
-
SnowparkSQLException: if the container is pending
|
213
264
|
RuntimeError: if failed to get head instance_id
|
214
|
-
|
215
265
|
"""
|
216
266
|
# If instance_id is not specified, try to get the head instance ID
|
217
267
|
if instance_id is None:
|
218
268
|
try:
|
219
269
|
instance_id = _get_head_instance_id(session, job_id)
|
220
270
|
except RuntimeError:
|
221
|
-
|
222
|
-
"Failed to retrieve job logs. "
|
223
|
-
"Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
|
224
|
-
)
|
271
|
+
instance_id = None
|
225
272
|
|
226
273
|
# Assemble params: [job_id, instance_id, container_name, (optional) limit]
|
227
274
|
params: list[Any] = [
|
@@ -231,7 +278,6 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
|
|
231
278
|
]
|
232
279
|
if limit > 0:
|
233
280
|
params.append(limit)
|
234
|
-
|
235
281
|
try:
|
236
282
|
(row,) = session.sql(
|
237
283
|
f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
|
@@ -239,9 +285,43 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
|
|
239
285
|
).collect()
|
240
286
|
except SnowparkSQLException as e:
|
241
287
|
if "Container Status: PENDING" in e.message:
|
242
|
-
|
243
|
-
|
244
|
-
|
288
|
+
logger.warning("Waiting for container to start. Logs will be shown when available.")
|
289
|
+
return ""
|
290
|
+
else:
|
291
|
+
# event table accepts job name, not fully qualified name
|
292
|
+
# cast is to resolve the type check error
|
293
|
+
db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
|
294
|
+
db = cast(str, db or session.get_current_database())
|
295
|
+
schema = cast(str, schema or session.get_current_schema())
|
296
|
+
logs = _get_service_log_from_event_table(
|
297
|
+
session, db, schema, name, limit, instance_id if instance_id else None
|
298
|
+
)
|
299
|
+
if len(logs) == 0:
|
300
|
+
raise RuntimeError(
|
301
|
+
"No logs were found. Please verify that the database, schema, and job ID are correct."
|
302
|
+
)
|
303
|
+
return os.linesep.join(row[0] for row in logs)
|
304
|
+
|
305
|
+
full_log = str(row[0])
|
306
|
+
|
307
|
+
# If verbose is True, return the complete log
|
308
|
+
if verbose:
|
309
|
+
return full_log
|
310
|
+
|
311
|
+
# Otherwise, extract only the portion between LOG_START_MSG and LOG_END_MSG
|
312
|
+
start_idx = full_log.find(constants.LOG_START_MSG)
|
313
|
+
if start_idx != -1:
|
314
|
+
start_idx += len(constants.LOG_START_MSG)
|
315
|
+
else:
|
316
|
+
# If start message not found, start from the beginning
|
317
|
+
start_idx = 0
|
318
|
+
|
319
|
+
end_idx = full_log.find(constants.LOG_END_MSG, start_idx)
|
320
|
+
if end_idx == -1:
|
321
|
+
# If end message not found, return everything after start
|
322
|
+
end_idx = len(full_log)
|
323
|
+
|
324
|
+
return full_log[start_idx:end_idx].strip()
|
245
325
|
|
246
326
|
|
247
327
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
@@ -256,13 +336,25 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
256
336
|
Returns:
|
257
337
|
Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
|
258
338
|
|
259
|
-
|
339
|
+
Raises:
|
260
340
|
RuntimeError: If the instances died or if some instances disappeared.
|
261
341
|
"""
|
262
|
-
|
342
|
+
|
343
|
+
target_instances = _get_target_instances(session, job_id)
|
344
|
+
|
345
|
+
if target_instances == 1:
|
346
|
+
return 0
|
347
|
+
|
348
|
+
try:
|
349
|
+
rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
350
|
+
except SnowparkSQLException:
|
351
|
+
# service may be deleted
|
352
|
+
raise RuntimeError("Couldn’t retrieve instances")
|
353
|
+
|
263
354
|
if not rows:
|
264
355
|
return None
|
265
|
-
|
356
|
+
|
357
|
+
if target_instances > len(rows):
|
266
358
|
raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
|
267
359
|
|
268
360
|
# Sort by start_time first, then by instance_id
|
@@ -270,7 +362,6 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
270
362
|
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
271
363
|
except TypeError:
|
272
364
|
raise RuntimeError("Job instance information unavailable.")
|
273
|
-
|
274
365
|
head_instance = sorted_instances[0]
|
275
366
|
if not head_instance["start_time"]:
|
276
367
|
# If head instance hasn't started yet, return None
|
@@ -281,12 +372,69 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
281
372
|
return 0
|
282
373
|
|
283
374
|
|
375
|
+
def _get_service_log_from_event_table(
    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
) -> list[Row]:
    """Fetch a service's log records from ``snowflake.telemetry.events_view``.

    Args:
        session: The Snowpark session used to run the query.
        database: Database name matched against ``snow.database.name``.
        schema: Schema name matched against ``snow.schema.name``.
        name: Service (job) name matched against ``snow.service.name``.
        limit: Maximum number of rows to return; values <= 0 mean no limit.
        instance_id: Restrict the result to one container instance. ``None``
            returns logs from all instances.

    Returns:
        Rows with a single ``VALUE`` column, ordered by ``TIMESTAMP``.
    """
    params: list[Any] = [
        database,
        schema,
        name,
    ]
    query = [
        "SELECT VALUE FROM snowflake.telemetry.events_view",
        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
    ]

    # Compare against None explicitly: instance 0 is a valid instance ID and
    # must not be mistaken for "no instance filter".
    if instance_id is not None:
        query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
        params.append(instance_id)

    query.append("AND RECORD_TYPE = 'LOG'")
    # sort by TIMESTAMP; although OBSERVED_TIMESTAMP is for log, it is NONE currently when record_type is log
    query.append("ORDER BY TIMESTAMP")

    if limit > 0:
        query.append("LIMIT ?")
        params.append(limit)

    return session.sql("\n".join(query), params=params).collect()
|
407
|
+
|
408
|
+
|
284
409
|
def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
    """Run ``DESCRIBE SERVICE`` for the job and return its single result row.

    Unpacking enforces that exactly one row comes back; any other row count
    raises ``ValueError``.
    """
    described = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
    (service_row,) = described
    return service_row
|
287
412
|
|
288
413
|
|
414
|
+
def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
    """
    Look up a compute pool with ``SHOW COMPUTE POOLS`` and return its row.

    Args:
        session (Session): The Snowpark session to use.
        compute_pool (str): The name of the compute pool.

    Returns:
        Row: The compute pool information.

    Raises:
        ValueError: If the compute pool is not found, or if the name matches
            several pools and none of them matches it exactly.
    """
    rows = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
    if not rows:
        raise ValueError(f"Compute pool '{compute_pool}' not found")

    if len(rows) > 1:
        # LIKE treats "_" and "%" as wildcards, so a name such as MY_POOL can
        # match several pools. Keep only the case-insensitive exact match.
        wanted = compute_pool.casefold()
        exact = [row for row in rows if str(row["name"]).casefold() == wanted]
        if len(exact) != 1:
            raise ValueError(
                f"Compute pool name '{compute_pool}' is ambiguous: matched {len(rows)} compute pools"
            )
        rows = exact

    return rows[0]
|
435
|
+
|
436
|
+
|
289
437
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
290
|
-
def
|
438
|
+
def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
|
291
439
|
row = _get_service_info(session, job_id)
|
292
|
-
return int(row["target_instances"])
|
440
|
+
return int(row["target_instances"])
|
snowflake/ml/jobs/manager.py
CHANGED
@@ -87,13 +87,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
|
|
87
87
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
88
88
|
def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
|
89
89
|
"""Delete a job service from the backend. Status and logs will be lost."""
|
90
|
-
if isinstance(job, jb.MLJob)
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
90
|
+
job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
|
91
|
+
session = job._session
|
92
|
+
try:
|
93
|
+
stage_path = job._stage_path
|
94
|
+
session.sql(f"REMOVE {stage_path}/").collect()
|
95
|
+
logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
|
96
|
+
except Exception as e:
|
97
|
+
logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
|
98
|
+
session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()
|
97
99
|
|
98
100
|
|
99
101
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
@@ -108,7 +110,8 @@ def submit_file(
|
|
108
110
|
external_access_integrations: Optional[list[str]] = None,
|
109
111
|
query_warehouse: Optional[str] = None,
|
110
112
|
spec_overrides: Optional[dict[str, Any]] = None,
|
111
|
-
|
113
|
+
target_instances: int = 1,
|
114
|
+
min_instances: Optional[int] = None,
|
112
115
|
enable_metrics: bool = False,
|
113
116
|
database: Optional[str] = None,
|
114
117
|
schema: Optional[str] = None,
|
@@ -127,7 +130,9 @@ def submit_file(
|
|
127
130
|
external_access_integrations: A list of external access integrations.
|
128
131
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
129
132
|
spec_overrides: Custom service specification overrides to apply.
|
130
|
-
|
133
|
+
target_instances: The number of instances to use for the job. If none specified, single node job is created.
|
134
|
+
min_instances: The minimum number of nodes required to start the job. If none specified,
|
135
|
+
defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
|
131
136
|
enable_metrics: Whether to enable metrics publishing for the job.
|
132
137
|
database: The database to use.
|
133
138
|
schema: The schema to use.
|
@@ -146,7 +151,8 @@ def submit_file(
|
|
146
151
|
external_access_integrations=external_access_integrations,
|
147
152
|
query_warehouse=query_warehouse,
|
148
153
|
spec_overrides=spec_overrides,
|
149
|
-
|
154
|
+
target_instances=target_instances,
|
155
|
+
min_instances=min_instances,
|
150
156
|
enable_metrics=enable_metrics,
|
151
157
|
database=database,
|
152
158
|
schema=schema,
|
@@ -167,7 +173,8 @@ def submit_directory(
|
|
167
173
|
external_access_integrations: Optional[list[str]] = None,
|
168
174
|
query_warehouse: Optional[str] = None,
|
169
175
|
spec_overrides: Optional[dict[str, Any]] = None,
|
170
|
-
|
176
|
+
target_instances: int = 1,
|
177
|
+
min_instances: Optional[int] = None,
|
171
178
|
enable_metrics: bool = False,
|
172
179
|
database: Optional[str] = None,
|
173
180
|
schema: Optional[str] = None,
|
@@ -187,7 +194,9 @@ def submit_directory(
|
|
187
194
|
external_access_integrations: A list of external access integrations.
|
188
195
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
189
196
|
spec_overrides: Custom service specification overrides to apply.
|
190
|
-
|
197
|
+
target_instances: The number of instances to use for the job. If none specified, single node job is created.
|
198
|
+
min_instances: The minimum number of nodes required to start the job. If none specified,
|
199
|
+
defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
|
191
200
|
enable_metrics: Whether to enable metrics publishing for the job.
|
192
201
|
database: The database to use.
|
193
202
|
schema: The schema to use.
|
@@ -207,7 +216,74 @@ def submit_directory(
|
|
207
216
|
external_access_integrations=external_access_integrations,
|
208
217
|
query_warehouse=query_warehouse,
|
209
218
|
spec_overrides=spec_overrides,
|
210
|
-
|
219
|
+
target_instances=target_instances,
|
220
|
+
min_instances=min_instances,
|
221
|
+
enable_metrics=enable_metrics,
|
222
|
+
database=database,
|
223
|
+
schema=schema,
|
224
|
+
session=session,
|
225
|
+
)
|
226
|
+
|
227
|
+
|
228
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
229
|
+
def submit_from_stage(
|
230
|
+
source: str,
|
231
|
+
compute_pool: str,
|
232
|
+
*,
|
233
|
+
entrypoint: str,
|
234
|
+
stage_name: str,
|
235
|
+
args: Optional[list[str]] = None,
|
236
|
+
env_vars: Optional[dict[str, str]] = None,
|
237
|
+
pip_requirements: Optional[list[str]] = None,
|
238
|
+
external_access_integrations: Optional[list[str]] = None,
|
239
|
+
query_warehouse: Optional[str] = None,
|
240
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
241
|
+
target_instances: int = 1,
|
242
|
+
min_instances: Optional[int] = None,
|
243
|
+
enable_metrics: bool = False,
|
244
|
+
database: Optional[str] = None,
|
245
|
+
schema: Optional[str] = None,
|
246
|
+
session: Optional[snowpark.Session] = None,
|
247
|
+
) -> jb.MLJob[None]:
|
248
|
+
"""
|
249
|
+
Submit a directory containing Python script(s) as a job to the compute pool.
|
250
|
+
|
251
|
+
Args:
|
252
|
+
source: a stage path or a stage containing the job payload.
|
253
|
+
compute_pool: The compute pool to use for the job.
|
254
|
+
entrypoint: a stage path containing the entry point script inside the source directory.
|
255
|
+
stage_name: The name of the stage where the job payload will be uploaded.
|
256
|
+
args: A list of arguments to pass to the job.
|
257
|
+
env_vars: Environment variables to set in container
|
258
|
+
pip_requirements: A list of pip requirements for the job.
|
259
|
+
external_access_integrations: A list of external access integrations.
|
260
|
+
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
261
|
+
spec_overrides: Custom service specification overrides to apply.
|
262
|
+
target_instances: The number of instances to use for the job. If none specified, single node job is created.
|
263
|
+
min_instances: The minimum number of nodes required to start the job. If none specified,
|
264
|
+
defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
|
265
|
+
enable_metrics: Whether to enable metrics publishing for the job.
|
266
|
+
database: The database to use.
|
267
|
+
schema: The schema to use.
|
268
|
+
session: The Snowpark session to use. If none specified, uses active session.
|
269
|
+
|
270
|
+
|
271
|
+
Returns:
|
272
|
+
An object representing the submitted job.
|
273
|
+
"""
|
274
|
+
return _submit_job(
|
275
|
+
source=source,
|
276
|
+
entrypoint=entrypoint,
|
277
|
+
args=args,
|
278
|
+
compute_pool=compute_pool,
|
279
|
+
stage_name=stage_name,
|
280
|
+
env_vars=env_vars,
|
281
|
+
pip_requirements=pip_requirements,
|
282
|
+
external_access_integrations=external_access_integrations,
|
283
|
+
query_warehouse=query_warehouse,
|
284
|
+
spec_overrides=spec_overrides,
|
285
|
+
target_instances=target_instances,
|
286
|
+
min_instances=min_instances,
|
211
287
|
enable_metrics=enable_metrics,
|
212
288
|
database=database,
|
213
289
|
schema=schema,
|
@@ -228,7 +304,8 @@ def _submit_job(
|
|
228
304
|
external_access_integrations: Optional[list[str]] = None,
|
229
305
|
query_warehouse: Optional[str] = None,
|
230
306
|
spec_overrides: Optional[dict[str, Any]] = None,
|
231
|
-
|
307
|
+
target_instances: int = 1,
|
308
|
+
min_instances: Optional[int] = None,
|
232
309
|
enable_metrics: bool = False,
|
233
310
|
database: Optional[str] = None,
|
234
311
|
schema: Optional[str] = None,
|
@@ -250,7 +327,8 @@ def _submit_job(
|
|
250
327
|
external_access_integrations: Optional[list[str]] = None,
|
251
328
|
query_warehouse: Optional[str] = None,
|
252
329
|
spec_overrides: Optional[dict[str, Any]] = None,
|
253
|
-
|
330
|
+
target_instances: int = 1,
|
331
|
+
min_instances: Optional[int] = None,
|
254
332
|
enable_metrics: bool = False,
|
255
333
|
database: Optional[str] = None,
|
256
334
|
schema: Optional[str] = None,
|
@@ -267,7 +345,7 @@ def _submit_job(
|
|
267
345
|
# TODO: Log lengths of args, env_vars, and spec_overrides values
|
268
346
|
"pip_requirements",
|
269
347
|
"external_access_integrations",
|
270
|
-
"
|
348
|
+
"target_instances",
|
271
349
|
"enable_metrics",
|
272
350
|
],
|
273
351
|
)
|
@@ -283,7 +361,8 @@ def _submit_job(
|
|
283
361
|
external_access_integrations: Optional[list[str]] = None,
|
284
362
|
query_warehouse: Optional[str] = None,
|
285
363
|
spec_overrides: Optional[dict[str, Any]] = None,
|
286
|
-
|
364
|
+
target_instances: int = 1,
|
365
|
+
min_instances: Optional[int] = None,
|
287
366
|
enable_metrics: bool = False,
|
288
367
|
database: Optional[str] = None,
|
289
368
|
schema: Optional[str] = None,
|
@@ -303,7 +382,9 @@ def _submit_job(
|
|
303
382
|
external_access_integrations: A list of external access integrations.
|
304
383
|
query_warehouse: The query warehouse to use. Defaults to session warehouse.
|
305
384
|
spec_overrides: Custom service specification overrides to apply.
|
306
|
-
|
385
|
+
target_instances: The number of instances to use for the job. If none specified, single node job is created.
|
386
|
+
min_instances: The minimum number of nodes required to start the job. If none specified,
|
387
|
+
defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
|
307
388
|
enable_metrics: Whether to enable metrics publishing for the job.
|
308
389
|
database: The database to use.
|
309
390
|
schema: The schema to use.
|
@@ -316,16 +397,27 @@ def _submit_job(
|
|
316
397
|
RuntimeError: If required Snowflake features are not enabled.
|
317
398
|
ValueError: If database or schema value(s) are invalid
|
318
399
|
"""
|
319
|
-
# Display warning about PrPr parameters
|
320
|
-
if num_instances is not None:
|
321
|
-
logger.warning(
|
322
|
-
"_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
|
323
|
-
)
|
324
400
|
if database and not schema:
|
325
401
|
raise ValueError("Schema must be specified if database is specified.")
|
402
|
+
if target_instances < 1:
|
403
|
+
raise ValueError("target_instances must be greater than 0.")
|
404
|
+
|
405
|
+
min_instances = target_instances if min_instances is None else min_instances
|
406
|
+
if not (0 < min_instances <= target_instances):
|
407
|
+
raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
|
326
408
|
|
327
409
|
session = session or get_active_session()
|
328
410
|
|
411
|
+
if min_instances > 1:
|
412
|
+
# Validate min_instances against compute pool max_nodes
|
413
|
+
pool_info = jb._get_compute_pool_info(session, compute_pool)
|
414
|
+
max_nodes = int(pool_info["max_nodes"])
|
415
|
+
if min_instances > max_nodes:
|
416
|
+
raise ValueError(
|
417
|
+
f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
|
418
|
+
f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
|
419
|
+
)
|
420
|
+
|
329
421
|
# Validate database and schema identifiers on client side since
|
330
422
|
# SQL parser for EXECUTE JOB SERVICE seems to struggle with this
|
331
423
|
database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
|
@@ -350,7 +442,8 @@ def _submit_job(
|
|
350
442
|
compute_pool=compute_pool,
|
351
443
|
payload=uploaded_payload,
|
352
444
|
args=args,
|
353
|
-
|
445
|
+
target_instances=target_instances,
|
446
|
+
min_instances=min_instances,
|
354
447
|
enable_metrics=enable_metrics,
|
355
448
|
)
|
356
449
|
spec_overrides = spec_utils.generate_spec_overrides(
|
@@ -381,9 +474,9 @@ def _submit_job(
|
|
381
474
|
if query_warehouse:
|
382
475
|
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
383
476
|
params.append(query_warehouse)
|
384
|
-
if
|
477
|
+
if target_instances > 1:
|
385
478
|
query.append("REPLICAS = ?")
|
386
|
-
params.append(
|
479
|
+
params.append(target_instances)
|
387
480
|
|
388
481
|
# Submit job
|
389
482
|
query_text = "\n".join(line for line in query if line)
|
@@ -426,3 +426,61 @@ class Model:
|
|
426
426
|
schema_name=new_schema or self._model_ops._model_client._schema_name,
|
427
427
|
)
|
428
428
|
self._model_name = new_model
|
429
|
+
|
430
|
+
def _repr_html_(self) -> str:
    """Render the model as an HTML summary.

    The summary contains a details grid (name, full name, description,
    default version), a "Versions" section and a "Tags" section. If versions
    or tags cannot be retrieved, that section shows an error message instead.

    Returns:
        str: HTML string containing formatted model details.
    """
    from snowflake.ml.utils import html_utils

    # Resolved first and outside any guard: a failure here propagates.
    default_version = self.default.version_name

    # Versions section body
    try:
        versions_df = self.show_versions()
        versions_html = "".join(
            html_utils.create_version_item(
                version_name=version_row["name"],
                created_on=str(version_row["created_on"]),
                comment=str(version_row.get("comment", "")),
                is_default=bool(version_row["is_default_version"]),
            )
            for _, version_row in versions_df.iterrows()
        )
    except Exception:
        versions_html = html_utils.create_error_message("Error retrieving versions")

    # Tags section body
    try:
        tags = self.show_tags()
        if tags:
            tags_html = "".join(html_utils.create_tag_item(key, value) for key, value in tags.items())
        else:
            tags_html = html_utils.create_error_message("No tags available")
    except Exception:
        tags_html = html_utils.create_error_message("Error retrieving tags")

    details = [
        ("Model Name", self.name),
        ("Full Name", self.fully_qualified_name),
        ("Description", self.description),
        ("Default Version", default_version),
    ]
    main_info = html_utils.create_grid_section(details)

    versions_header = html_utils.create_section_header("Versions")
    versions_section = versions_header + html_utils.create_content_section(versions_html)

    tags_header = html_utils.create_section_header("Tags")
    tags_section = tags_header + html_utils.create_content_section(tags_html)

    return html_utils.create_base_container("Model Details", main_info + versions_section + tags_section)
|