snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +6 -5
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/spec_utils.py +6 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +179 -58
- snowflake/ml/jobs/manager.py +194 -145
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +4 -2
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +119 -42
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -7,9 +8,11 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 import yaml

 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
-from snowflake.ml.
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException

@@ -21,7 +24,7 @@ T = TypeVar("T")
 logger = logging.getLogger(__name__)


-class MLJob(Generic[T]):
+class MLJob(Generic[T], SerializableSessionMixin):
     def __init__(
         self,
         id: str,
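Note: MLJob now also derives from SerializableSessionMixin, added in this release in snowflake/ml/_internal/utils/mixins.py (+61 lines, not part of this file's diff). The mixin's body is not shown here; the following is a rough, purely illustrative sketch of the pattern such a mixin typically implements (drop the non-picklable session when serializing, rebind the active session afterwards). The class name, attribute name, and behavior below are assumptions, not the shipped code.

    from typing import Any

    from snowflake.snowpark import context as sp_context


    class SessionPicklingSketch:
        """Illustrative only: not the shipped SerializableSessionMixin."""

        def __getstate__(self) -> dict[str, Any]:
            state = self.__dict__.copy()
            state.pop("_session", None)  # a live Snowpark session cannot be pickled
            return state

        def __setstate__(self, state: dict[str, Any]) -> None:
            self.__dict__.update(state)
            # Rebind whatever session is active in the current process.
            self._session = sp_context.get_active_session()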
@@ -67,7 +70,8 @@ class MLJob(Generic[T]):
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
         row = _get_service_info(self._session, self.id)
-
+        compute_pool = row[query_helper.get_attribute_map(self._session, {"compute_pool": 5})["compute_pool"]]
+        return cast(str, compute_pool)

     @property
     def _service_spec(self) -> dict[str, Any]:
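Note: several hunks below read SHOW/DESCRIBE results by column position through query_helper.get_attribute_map (query_helper.py is new in this release and its body is not shown in this file's diff). Restating the call-site contract seen above, with illustrative variable names:

    # get_attribute_map(session, {"attr": default_index}) returns a name-to-index map;
    # the caller then indexes the raw DESCRIBE SERVICE row positionally.
    row = _get_service_info(session, job_id)
    columns = query_helper.get_attribute_map(session, {"compute_pool": 5})
    compute_pool_name = row[columns["compute_pool"]]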
@@ -181,9 +185,13 @@
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             if status == "PENDING" and not warning_shown:
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
-
+                requested_attributes = {"max_nodes": 3, "active_nodes": 9}
+                if (
+                    pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
+                ) < self.min_instances:
                     logger.warning(
-                        f
+                        f'Compute pool busy ({pool_info[requested_attributes["active_nodes"]]}'
+                        f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
                         f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
@@ -220,27 +228,46 @@
             return cast(T, self._result.result)
         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception

+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def cancel(self) -> None:
+        """
+        Cancel the running job.
+        Raises:
+            RuntimeError: If cancellation fails. # noqa: DAR401
+        """
+        try:
+            self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
+            logger.debug(f"Cancellation requested for job {self.id}")
+        except SnowparkSQLException as e:
+            raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+

 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
     if instance_id is not None:
         # Get specific instance status
-        rows = session.
-
-
-
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job_id], _force_qmark_paramstyle=True
+        )
+        request_attributes = query_helper.get_attribute_map(session, {"status": 5, "instance_id": 4})
+        if isinstance(rows, dict) and "data" in rows:
+            for row in rows["data"]:
+                if row[request_attributes["instance_id"]] == str(instance_id):
+                    return cast(types.JOB_STATUS, row[request_attributes["status"]])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
         row = _get_service_info(session, job_id)
-
+        request_attributes = query_helper.get_attribute_map(session, {"status": 1})
+        return cast(types.JOB_STATUS, row[request_attributes["status"]])


 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
     row = _get_service_info(session, job_id)
-
+    requested_attributes = query_helper.get_attribute_map(session, {"spec": 6})
+    return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))


 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
@@ -262,6 +289,7 @@ def _get_logs(

     Raises:
         RuntimeError: if failed to get head instance_id
+        SnowparkSQLException: if there is an error retrieving logs from SPCS interface.
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
@@ -279,30 +307,59 @@
     if limit > 0:
         params.append(limit)
     try:
-
+        data = session._conn.run_query(
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-
-
-        if
+            _force_qmark_paramstyle=True,
+        )
+        if isinstance(data, dict) and "data" in data:
+            full_log = str(data["data"][0][0])
+            # pass type check
+        else:
+            full_log = ""
+    except errors.ProgrammingError as e:
+        if "Container Status: PENDING" in str(e):
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
-            #
-            #
-
-
-
-
-
-
-
-
-
+            # Fallback plan:
+            # 1. Try SPCS Interface (doesn't require event table permission)
+            # 2. If the interface call fails, query Event Table (requires permission)
+            logger.debug("falling back to SPCS Interface for logs")
+            try:
+                logs = _get_logs_spcs(
+                    session,
+                    job_id,
+                    limit=limit,
+                    instance_id=instance_id if instance_id else 0,
+                    container_name=constants.DEFAULT_CONTAINER_NAME,
                 )
-
-
-
+                full_log = os.linesep.join(row[0] for row in logs)
+
+            except SnowparkSQLException as spcs_error:
+                if spcs_error.sql_error_code == 2143:
+                    logger.debug("persistent logs may not be enabled, falling back to event table")
+                else:
+                    # If SPCS Interface fails for any other reason,
+                    # for example, incorrect argument format,raise the error directly
+                    raise
+                # event table accepts job name, not fully qualified name
+                db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+                db = db or session.get_current_database()
+                schema = schema or session.get_current_schema()
+                event_table_logs = _get_service_log_from_event_table(
+                    session,
+                    name,
+                    database=db,
+                    schema=schema,
+                    instance_id=instance_id if instance_id else 0,
+                    limit=limit,
+                )
+                if len(event_table_logs) == 0:
+                    raise RuntimeError(
+                        "No logs were found. Please verify that the database, schema, and job ID are correct."
+                    )
+                full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)

     # If verbose is True, return the complete log
     if verbose:
@@ -340,52 +397,70 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
         RuntimeError: If the instances died or if some instances disappeared.
     """

-
+    try:
+        target_instances = _get_target_instances(session, job_id)
+    except errors.ProgrammingError:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve service information")

     if target_instances == 1:
         return 0

     try:
-        rows = session.
-
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True
+        )
+    except errors.ProgrammingError:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")

-    if not rows:
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         return None

-    if target_instances > len(rows):
+    if target_instances > len(rows["data"]):
         raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")

+    requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
     # Sort by start_time first, then by instance_id
     try:
-        sorted_instances = sorted(
+        sorted_instances = sorted(
+            rows["data"],
+            key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
+        )
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
-    if not head_instance["start_time"]:
+    if not head_instance[requested_attributes["start_time"]]:
         # If head instance hasn't started yet, return None
         return None
     try:
-        return int(head_instance["instance_id"])
+        return int(head_instance[requested_attributes["instance_id"]])
     except (ValueError, TypeError):
         return 0


 def _get_service_log_from_event_table(
-    session: snowpark.Session,
-
+    session: snowpark.Session,
+    name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    instance_id: Optional[int] = None,
+    limit: int = -1,
+) -> Any:
     params: list[Any] = [
-        database,
-        schema,
         name,
     ]
     query = [
         "SELECT VALUE FROM snowflake.telemetry.events_view",
-        'WHERE RESOURCE_ATTRIBUTES:"snow.
-        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
     ]
+    if database:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
+        params.append(database)
+
+    if schema:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
+        params.append(schema)

     if instance_id:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
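Note: the event-table path builds its query incrementally, adding the database, schema, instance, and LIMIT filters only when those values are supplied; the instance and LIMIT parts follow the same pattern in the next hunk. With a name, database, and schema given, the assembled pieces before qmark binding look like the following (placeholder values):

    query = [
        "SELECT VALUE FROM snowflake.telemetry.events_view",
        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
    ]
    params = ["MY_JOB", "MY_DB", "MY_SCHEMA"]  # placeholder values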
@@ -398,20 +473,23 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-
-
-
-
-    return rows
+    rows = session._conn.run_query(
+        "\n".join(line for line in query if line), params=params, _force_qmark_paramstyle=True
+    )
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
+        return []
+    return rows["data"]


-def _get_service_info(session: snowpark.Session, job_id: str) ->
-
-
+def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
+    row = session._conn.run_query("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True)
+    # pass the type check
+    if not row or not isinstance(row, dict) or not row.get("data"):
+        raise errors.ProgrammingError("failed to retrieve service information")
+    return row["data"][0]


-def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) ->
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
     """
     Check if the compute pool has enough available instances.

@@ -420,14 +498,19 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
         compute_pool (str): The name of the compute pool.

     Returns:
-
+        Any: The compute pool information.

     Raises:
         ValueError: If the compute pool is not found.
     """
     try:
-
-
+        compute_pool_info = session._conn.run_query(
+            "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,), _force_qmark_paramstyle=True
+        )
+        # pass the type check
+        if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
+            raise ValueError(f"Compute pool '{compute_pool}' not found")
+        return compute_pool_info["data"][0]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
@@ -437,4 +520,42 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-
+    requested_attributes = query_helper.get_attribute_map(session, {"target_instances": 9})
+    return int(row[requested_attributes["target_instances"]])
+
+
+def _get_logs_spcs(
+    session: snowpark.Session,
+    fully_qualified_name: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: Optional[str] = None,
+    start_time: Optional[str] = None,
+    end_time: Optional[str] = None,
+) -> list[Row]:
+    query = [
+        f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
+    ]
+    conditions_params = []
+    if start_time:
+        conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
+    if end_time:
+        conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
+    if len(conditions_params) > 0:
+        query.append(", ".join(conditions_params))
+
+    query.append("))")
+
+    query_params = []
+    if instance_id is not None:
+        query_params.append(f"INSTANCE_ID = {instance_id}")
+    if container_name:
+        query_params.append(f"CONTAINER_NAME = '{container_name}'")
+    if len(query_params) > 0:
+        query.append("WHERE " + " AND ".join(query_params))
+
+    query.append("ORDER BY TIMESTAMP ASC")
+    if limit > 0:
+        query.append(f" LIMIT {limit};")
+    rows = session.sql("\n".join(query)).collect()
+    return rows
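Note: _get_logs_spcs composes the spcs_get_logs table-function query as plain strings. Tracing that assembly, a call such as _get_logs_spcs(session, "MY_DB.MY_SCHEMA.MY_JOB", limit=100, instance_id=0, container_name="main") (placeholder names) executes roughly the statement built below; the time-range arguments are omitted when start_time/end_time are None.

    sql_text = "\n".join([
        "SELECT LOG FROM table(MY_DB.MY_SCHEMA.MY_JOB!spcs_get_logs(",
        "))",
        "WHERE INSTANCE_ID = 0 AND CONTAINER_NAME = 'main'",
        "ORDER BY TIMESTAMP ASC",
        " LIMIT 100;",
    ])
    rows = session.sql(sql_text).collect()  # each row carries one LOG line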