snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
import time
|
@@ -7,21 +8,23 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overlo
|
|
7
8
|
import yaml
|
8
9
|
|
9
10
|
from snowflake import snowpark
|
11
|
+
from snowflake.connector import errors
|
10
12
|
from snowflake.ml._internal import telemetry
|
11
13
|
from snowflake.ml._internal.utils import identifier
|
12
|
-
from snowflake.ml.
|
14
|
+
from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
|
15
|
+
from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
|
13
16
|
from snowflake.snowpark import Row, context as sp_context
|
14
17
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
15
18
|
|
16
19
|
_PROJECT = "MLJob"
|
17
|
-
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
20
|
+
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
|
18
21
|
|
19
22
|
T = TypeVar("T")
|
20
23
|
|
21
24
|
logger = logging.getLogger(__name__)
|
22
25
|
|
23
26
|
|
24
|
-
class MLJob(Generic[T]):
|
27
|
+
class MLJob(Generic[T], SerializableSessionMixin):
|
25
28
|
def __init__(
|
26
29
|
self,
|
27
30
|
id: str,
|
@@ -67,7 +70,8 @@ class MLJob(Generic[T]):
|
|
67
70
|
def _compute_pool(self) -> str:
|
68
71
|
"""Get the job's compute pool name."""
|
69
72
|
row = _get_service_info(self._session, self.id)
|
70
|
-
|
73
|
+
compute_pool = row[query_helper.get_attribute_map(self._session, {"compute_pool": 5})["compute_pool"]]
|
74
|
+
return cast(str, compute_pool)
|
71
75
|
|
72
76
|
@property
|
73
77
|
def _service_spec(self) -> dict[str, Any]:
|
@@ -181,16 +185,20 @@ class MLJob(Generic[T]):
|
|
181
185
|
while (status := self.status) not in TERMINAL_JOB_STATUSES:
|
182
186
|
if status == "PENDING" and not warning_shown:
|
183
187
|
pool_info = _get_compute_pool_info(self._session, self._compute_pool)
|
184
|
-
|
188
|
+
requested_attributes = {"max_nodes": 3, "active_nodes": 9}
|
189
|
+
if (
|
190
|
+
pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
|
191
|
+
) < self.min_instances:
|
185
192
|
logger.warning(
|
186
|
-
f
|
187
|
-
"
|
193
|
+
f'Compute pool busy ({pool_info[requested_attributes["active_nodes"]]}'
|
194
|
+
f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
|
195
|
+
f"{self.min_instances} nodes required). Job execution may be delayed."
|
188
196
|
)
|
189
197
|
warning_shown = True
|
190
198
|
if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
|
191
199
|
raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
|
192
200
|
time.sleep(delay)
|
193
|
-
delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
|
201
|
+
delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
|
194
202
|
return self.status
|
195
203
|
|
196
204
|
@snowpark._internal.utils.private_preview(version="1.8.2")
|
@@ -220,27 +228,46 @@ class MLJob(Generic[T]):
|
|
220
228
|
return cast(T, self._result.result)
|
221
229
|
raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
|
222
230
|
|
231
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
232
|
+
def cancel(self) -> None:
|
233
|
+
"""
|
234
|
+
Cancel the running job.
|
235
|
+
Raises:
|
236
|
+
RuntimeError: If cancellation fails. # noqa: DAR401
|
237
|
+
"""
|
238
|
+
try:
|
239
|
+
self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
|
240
|
+
logger.debug(f"Cancellation requested for job {self.id}")
|
241
|
+
except SnowparkSQLException as e:
|
242
|
+
raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
|
243
|
+
|
223
244
|
|
224
245
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
|
225
246
|
def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
|
226
247
|
"""Retrieve job or job instance execution status."""
|
227
248
|
if instance_id is not None:
|
228
249
|
# Get specific instance status
|
229
|
-
rows = session.
|
230
|
-
|
231
|
-
|
232
|
-
|
250
|
+
rows = session._conn.run_query(
|
251
|
+
"SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job_id], _force_qmark_paramstyle=True
|
252
|
+
)
|
253
|
+
request_attributes = query_helper.get_attribute_map(session, {"status": 5, "instance_id": 4})
|
254
|
+
if isinstance(rows, dict) and "data" in rows:
|
255
|
+
for row in rows["data"]:
|
256
|
+
if row[request_attributes["instance_id"]] == str(instance_id):
|
257
|
+
return cast(types.JOB_STATUS, row[request_attributes["status"]])
|
233
258
|
raise ValueError(f"Instance {instance_id} not found in job {job_id}")
|
234
259
|
else:
|
235
260
|
row = _get_service_info(session, job_id)
|
236
|
-
|
261
|
+
request_attributes = query_helper.get_attribute_map(session, {"status": 1})
|
262
|
+
return cast(types.JOB_STATUS, row[request_attributes["status"]])
|
237
263
|
|
238
264
|
|
239
265
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
240
266
|
def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
|
241
267
|
"""Retrieve job execution service spec."""
|
242
268
|
row = _get_service_info(session, job_id)
|
243
|
-
|
269
|
+
requested_attributes = query_helper.get_attribute_map(session, {"spec": 6})
|
270
|
+
return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))
|
244
271
|
|
245
272
|
|
246
273
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
|
@@ -262,6 +289,7 @@ def _get_logs(
|
|
262
289
|
|
263
290
|
Raises:
|
264
291
|
RuntimeError: if failed to get head instance_id
|
292
|
+
SnowparkSQLException: if there is an error retrieving logs from SPCS interface.
|
265
293
|
"""
|
266
294
|
# If instance_id is not specified, try to get the head instance ID
|
267
295
|
if instance_id is None:
|
@@ -279,30 +307,59 @@ def _get_logs(
|
|
279
307
|
if limit > 0:
|
280
308
|
params.append(limit)
|
281
309
|
try:
|
282
|
-
|
310
|
+
data = session._conn.run_query(
|
283
311
|
f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
|
284
312
|
params=params,
|
285
|
-
|
286
|
-
|
287
|
-
if
|
313
|
+
_force_qmark_paramstyle=True,
|
314
|
+
)
|
315
|
+
if isinstance(data, dict) and "data" in data:
|
316
|
+
full_log = str(data["data"][0][0])
|
317
|
+
# pass type check
|
318
|
+
else:
|
319
|
+
full_log = ""
|
320
|
+
except errors.ProgrammingError as e:
|
321
|
+
if "Container Status: PENDING" in str(e):
|
288
322
|
logger.warning("Waiting for container to start. Logs will be shown when available.")
|
289
323
|
return ""
|
290
324
|
else:
|
291
|
-
#
|
292
|
-
#
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
325
|
+
# Fallback plan:
|
326
|
+
# 1. Try SPCS Interface (doesn't require event table permission)
|
327
|
+
# 2. If the interface call fails, query Event Table (requires permission)
|
328
|
+
logger.debug("falling back to SPCS Interface for logs")
|
329
|
+
try:
|
330
|
+
logs = _get_logs_spcs(
|
331
|
+
session,
|
332
|
+
job_id,
|
333
|
+
limit=limit,
|
334
|
+
instance_id=instance_id if instance_id else 0,
|
335
|
+
container_name=constants.DEFAULT_CONTAINER_NAME,
|
302
336
|
)
|
303
|
-
|
304
|
-
|
305
|
-
|
337
|
+
full_log = os.linesep.join(row[0] for row in logs)
|
338
|
+
|
339
|
+
except SnowparkSQLException as spcs_error:
|
340
|
+
if spcs_error.sql_error_code == 2143:
|
341
|
+
logger.debug("persistent logs may not be enabled, falling back to event table")
|
342
|
+
else:
|
343
|
+
# If SPCS Interface fails for any other reason,
|
344
|
+
# for example, incorrect argument format,raise the error directly
|
345
|
+
raise
|
346
|
+
# event table accepts job name, not fully qualified name
|
347
|
+
db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
|
348
|
+
db = db or session.get_current_database()
|
349
|
+
schema = schema or session.get_current_schema()
|
350
|
+
event_table_logs = _get_service_log_from_event_table(
|
351
|
+
session,
|
352
|
+
name,
|
353
|
+
database=db,
|
354
|
+
schema=schema,
|
355
|
+
instance_id=instance_id if instance_id else 0,
|
356
|
+
limit=limit,
|
357
|
+
)
|
358
|
+
if len(event_table_logs) == 0:
|
359
|
+
raise RuntimeError(
|
360
|
+
"No logs were found. Please verify that the database, schema, and job ID are correct."
|
361
|
+
)
|
362
|
+
full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)
|
306
363
|
|
307
364
|
# If verbose is True, return the complete log
|
308
365
|
if verbose:
|
@@ -338,47 +395,72 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
338
395
|
|
339
396
|
Raises:
|
340
397
|
RuntimeError: If the instances died or if some instances disappeared.
|
341
|
-
|
342
398
|
"""
|
399
|
+
|
343
400
|
try:
|
344
|
-
|
345
|
-
except
|
401
|
+
target_instances = _get_target_instances(session, job_id)
|
402
|
+
except errors.ProgrammingError:
|
403
|
+
# service may be deleted
|
404
|
+
raise RuntimeError("Couldn’t retrieve service information")
|
405
|
+
|
406
|
+
if target_instances == 1:
|
407
|
+
return 0
|
408
|
+
|
409
|
+
try:
|
410
|
+
rows = session._conn.run_query(
|
411
|
+
"SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True
|
412
|
+
)
|
413
|
+
except errors.ProgrammingError:
|
346
414
|
# service may be deleted
|
347
415
|
raise RuntimeError("Couldn’t retrieve instances")
|
348
|
-
|
416
|
+
|
417
|
+
if not rows or not isinstance(rows, dict) or not rows.get("data"):
|
349
418
|
return None
|
350
|
-
|
419
|
+
|
420
|
+
if target_instances > len(rows["data"]):
|
351
421
|
raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
|
352
422
|
|
423
|
+
requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
|
353
424
|
# Sort by start_time first, then by instance_id
|
354
425
|
try:
|
355
|
-
sorted_instances = sorted(
|
426
|
+
sorted_instances = sorted(
|
427
|
+
rows["data"],
|
428
|
+
key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
|
429
|
+
)
|
356
430
|
except TypeError:
|
357
431
|
raise RuntimeError("Job instance information unavailable.")
|
358
432
|
head_instance = sorted_instances[0]
|
359
|
-
if not head_instance["start_time"]:
|
433
|
+
if not head_instance[requested_attributes["start_time"]]:
|
360
434
|
# If head instance hasn't started yet, return None
|
361
435
|
return None
|
362
436
|
try:
|
363
|
-
return int(head_instance["instance_id"])
|
437
|
+
return int(head_instance[requested_attributes["instance_id"]])
|
364
438
|
except (ValueError, TypeError):
|
365
439
|
return 0
|
366
440
|
|
367
441
|
|
368
442
|
def _get_service_log_from_event_table(
|
369
|
-
session: snowpark.Session,
|
370
|
-
|
443
|
+
session: snowpark.Session,
|
444
|
+
name: str,
|
445
|
+
database: Optional[str] = None,
|
446
|
+
schema: Optional[str] = None,
|
447
|
+
instance_id: Optional[int] = None,
|
448
|
+
limit: int = -1,
|
449
|
+
) -> Any:
|
371
450
|
params: list[Any] = [
|
372
|
-
database,
|
373
|
-
schema,
|
374
451
|
name,
|
375
452
|
]
|
376
453
|
query = [
|
377
454
|
"SELECT VALUE FROM snowflake.telemetry.events_view",
|
378
|
-
'WHERE RESOURCE_ATTRIBUTES:"snow.
|
379
|
-
'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
|
380
|
-
'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
|
455
|
+
'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
|
381
456
|
]
|
457
|
+
if database:
|
458
|
+
query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
|
459
|
+
params.append(database)
|
460
|
+
|
461
|
+
if schema:
|
462
|
+
query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
|
463
|
+
params.append(schema)
|
382
464
|
|
383
465
|
if instance_id:
|
384
466
|
query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
|
@@ -391,20 +473,23 @@ def _get_service_log_from_event_table(
|
|
391
473
|
if limit > 0:
|
392
474
|
query.append("LIMIT ?")
|
393
475
|
params.append(limit)
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
return rows
|
476
|
+
rows = session._conn.run_query(
|
477
|
+
"\n".join(line for line in query if line), params=params, _force_qmark_paramstyle=True
|
478
|
+
)
|
479
|
+
if not rows or not isinstance(rows, dict) or not rows.get("data"):
|
480
|
+
return []
|
481
|
+
return rows["data"]
|
400
482
|
|
401
483
|
|
402
|
-
def _get_service_info(session: snowpark.Session, job_id: str) ->
|
403
|
-
|
404
|
-
|
484
|
+
def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
|
485
|
+
row = session._conn.run_query("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True)
|
486
|
+
# pass the type check
|
487
|
+
if not row or not isinstance(row, dict) or not row.get("data"):
|
488
|
+
raise errors.ProgrammingError("failed to retrieve service information")
|
489
|
+
return row["data"][0]
|
405
490
|
|
406
491
|
|
407
|
-
def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) ->
|
492
|
+
def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
|
408
493
|
"""
|
409
494
|
Check if the compute pool has enough available instances.
|
410
495
|
|
@@ -413,13 +498,64 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
|
|
413
498
|
compute_pool (str): The name of the compute pool.
|
414
499
|
|
415
500
|
Returns:
|
416
|
-
|
501
|
+
Any: The compute pool information.
|
502
|
+
|
503
|
+
Raises:
|
504
|
+
ValueError: If the compute pool is not found.
|
417
505
|
"""
|
418
|
-
|
419
|
-
|
506
|
+
try:
|
507
|
+
compute_pool_info = session._conn.run_query(
|
508
|
+
"SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,), _force_qmark_paramstyle=True
|
509
|
+
)
|
510
|
+
# pass the type check
|
511
|
+
if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
|
512
|
+
raise ValueError(f"Compute pool '{compute_pool}' not found")
|
513
|
+
return compute_pool_info["data"][0]
|
514
|
+
except ValueError as e:
|
515
|
+
if "not enough values to unpack" in str(e):
|
516
|
+
raise ValueError(f"Compute pool '{compute_pool}' not found")
|
517
|
+
raise
|
420
518
|
|
421
519
|
|
422
520
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
423
521
|
def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
|
424
522
|
row = _get_service_info(session, job_id)
|
425
|
-
|
523
|
+
requested_attributes = query_helper.get_attribute_map(session, {"target_instances": 9})
|
524
|
+
return int(row[requested_attributes["target_instances"]])
|
525
|
+
|
526
|
+
|
527
|
+
def _get_logs_spcs(
|
528
|
+
session: snowpark.Session,
|
529
|
+
fully_qualified_name: str,
|
530
|
+
limit: int = -1,
|
531
|
+
instance_id: Optional[int] = None,
|
532
|
+
container_name: Optional[str] = None,
|
533
|
+
start_time: Optional[str] = None,
|
534
|
+
end_time: Optional[str] = None,
|
535
|
+
) -> list[Row]:
|
536
|
+
query = [
|
537
|
+
f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
|
538
|
+
]
|
539
|
+
conditions_params = []
|
540
|
+
if start_time:
|
541
|
+
conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
|
542
|
+
if end_time:
|
543
|
+
conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
|
544
|
+
if len(conditions_params) > 0:
|
545
|
+
query.append(", ".join(conditions_params))
|
546
|
+
|
547
|
+
query.append("))")
|
548
|
+
|
549
|
+
query_params = []
|
550
|
+
if instance_id is not None:
|
551
|
+
query_params.append(f"INSTANCE_ID = {instance_id}")
|
552
|
+
if container_name:
|
553
|
+
query_params.append(f"CONTAINER_NAME = '{container_name}'")
|
554
|
+
if len(query_params) > 0:
|
555
|
+
query.append("WHERE " + " AND ".join(query_params))
|
556
|
+
|
557
|
+
query.append("ORDER BY TIMESTAMP ASC")
|
558
|
+
if limit > 0:
|
559
|
+
query.append(f" LIMIT {limit};")
|
560
|
+
rows = session.sql("\n".join(query)).collect()
|
561
|
+
return rows
|