snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (30)
  1. snowflake/ml/_internal/utils/identifier.py +1 -1
  2. snowflake/ml/_internal/utils/mixins.py +61 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -1
  4. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  5. snowflake/ml/jobs/_utils/payload_utils.py +6 -5
  6. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  7. snowflake/ml/jobs/_utils/spec_utils.py +6 -4
  8. snowflake/ml/jobs/decorators.py +18 -25
  9. snowflake/ml/jobs/job.py +179 -58
  10. snowflake/ml/jobs/manager.py +194 -145
  11. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  12. snowflake/ml/model/_client/ops/service_ops.py +4 -2
  13. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  14. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  15. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  17. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  18. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  19. snowflake/ml/model/target_platform.py +11 -0
  20. snowflake/ml/model/task.py +9 -0
  21. snowflake/ml/model/type_hints.py +5 -13
  22. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  23. snowflake/ml/registry/_manager/model_manager.py +30 -15
  24. snowflake/ml/registry/registry.py +119 -42
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
  27. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
  28. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -7,9 +8,11 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overlo
 import yaml
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
-from snowflake.ml.jobs._utils import constants, interop_utils, types
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
@@ -21,7 +24,7 @@ T = TypeVar("T")
 logger = logging.getLogger(__name__)
 
 
-class MLJob(Generic[T]):
+class MLJob(Generic[T], SerializableSessionMixin):
     def __init__(
         self,
         id: str,
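
Note: `MLJob` now also inherits `SerializableSessionMixin` (new in `snowflake/ml/_internal/utils/mixins.py`, +61 lines, listed above). The mixin's body is not part of this diff; presumably it makes job handles picklable by dropping the live session on serialization and re-resolving one on load. A minimal sketch of that pattern, with hypothetical method bodies:

```python
from typing import Any

from snowflake.snowpark import Session


class SerializableSessionMixin:
    """Hypothetical sketch; the real mixin (+61 lines) may keep extra state."""

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("_session", None)  # live connections aren't picklable
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        # Reattach whatever session is active in the deserializing context.
        self._session = Session.builder.getOrCreate()
```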
@@ -67,7 +70,8 @@ class MLJob(Generic[T]):
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
         row = _get_service_info(self._session, self.id)
-        return cast(str, row["compute_pool"])
+        compute_pool = row[query_helper.get_attribute_map(self._session, {"compute_pool": 5})["compute_pool"]]
+        return cast(str, compute_pool)
 
     @property
     def _service_spec(self) -> dict[str, Any]:
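
Note: row access in this file switches from name-based lookups (`row["compute_pool"]`) on snowpark `Row` objects to positional lookups resolved through the new `query_helper.get_attribute_map` (+9 lines; its body is not shown in this diff). Judging by the hard-coded defaults such as `{"compute_pool": 5}`, it maps column names to positions in the raw tuples returned by `session._conn.run_query`. A hypothetical sketch:

```python
# Hypothetical reconstruction of snowflake/ml/jobs/_utils/query_helper.py;
# the shipped helper may instead resolve positions from result metadata.
from snowflake import snowpark


def get_attribute_map(session: snowpark.Session, desired_attributes: dict[str, int]) -> dict[str, int]:
    # Callers supply the expected zero-based column index for each attribute,
    # matching the column order of the SHOW/DESCRIBE command being parsed.
    return desired_attributes
```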
@@ -181,9 +185,13 @@ class MLJob(Generic[T]):
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             if status == "PENDING" and not warning_shown:
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
-                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                requested_attributes = {"max_nodes": 3, "active_nodes": 9}
+                if (
+                    pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
+                ) < self.min_instances:
                     logger.warning(
-                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
+                        f'Compute pool busy ({pool_info[requested_attributes["active_nodes"]]}'
+                        f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
                         f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
@@ -220,27 +228,46 @@ class MLJob(Generic[T]):
             return cast(T, self._result.result)
         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def cancel(self) -> None:
+        """
+        Cancel the running job.
+        Raises:
+            RuntimeError: If cancellation fails. # noqa: DAR401
+        """
+        try:
+            self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
+            logger.debug(f"Cancellation requested for job {self.id}")
+        except SnowparkSQLException as e:
+            raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
     if instance_id is not None:
         # Get specific instance status
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-        for row in rows:
-            if row["instance_id"] == str(instance_id):
-                return cast(types.JOB_STATUS, row["status"])
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job_id], _force_qmark_paramstyle=True
+        )
+        request_attributes = query_helper.get_attribute_map(session, {"status": 5, "instance_id": 4})
+        if isinstance(rows, dict) and "data" in rows:
+            for row in rows["data"]:
+                if row[request_attributes["instance_id"]] == str(instance_id):
+                    return cast(types.JOB_STATUS, row[request_attributes["status"]])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
         row = _get_service_info(session, job_id)
-        return cast(types.JOB_STATUS, row["status"])
+        request_attributes = query_helper.get_attribute_map(session, {"status": 1})
+        return cast(types.JOB_STATUS, row[request_attributes["status"]])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
     row = _get_service_info(session, job_id)
-    return cast(dict[str, Any], yaml.safe_load(row["spec"]))
+    requested_attributes = query_helper.get_attribute_map(session, {"spec": 6})
+    return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
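
Note: the new `MLJob.cancel()` wraps the service's `spcs_cancel_job()` procedure. A usage sketch, assuming the public `jobs.get_job` accessor and an illustrative job name:

```python
from snowflake.ml import jobs

job = jobs.get_job("MY_DB.MY_SCHEMA.MY_JOB")  # hypothetical job name
if job.status in ("PENDING", "RUNNING"):
    job.cancel()  # issues: CALL MY_DB.MY_SCHEMA.MY_JOB!spcs_cancel_job()
```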
@@ -262,6 +289,7 @@ def _get_logs(
 
     Raises:
         RuntimeError: if failed to get head instance_id
+        SnowparkSQLException: if there is an error retrieving logs from SPCS interface.
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
@@ -279,30 +307,59 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-        (row,) = session.sql(
+        data = session._conn.run_query(
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-        ).collect()
-    except SnowparkSQLException as e:
-        if "Container Status: PENDING" in e.message:
+            _force_qmark_paramstyle=True,
+        )
+        if isinstance(data, dict) and "data" in data:
+            full_log = str(data["data"][0][0])
+        # pass type check
+        else:
+            full_log = ""
+    except errors.ProgrammingError as e:
+        if "Container Status: PENDING" in str(e):
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
-            # event table accepts job name, not fully qualified name
-            # cast is to resolve the type check error
-            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
-            db = cast(str, db or session.get_current_database())
-            schema = cast(str, schema or session.get_current_schema())
-            logs = _get_service_log_from_event_table(
-                session, db, schema, name, limit, instance_id if instance_id else None
-            )
-            if len(logs) == 0:
-                raise RuntimeError(
-                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+            # Fallback plan:
+            # 1. Try SPCS Interface (doesn't require event table permission)
+            # 2. If the interface call fails, query Event Table (requires permission)
+            logger.debug("falling back to SPCS Interface for logs")
+            try:
+                logs = _get_logs_spcs(
+                    session,
+                    job_id,
+                    limit=limit,
+                    instance_id=instance_id if instance_id else 0,
+                    container_name=constants.DEFAULT_CONTAINER_NAME,
                 )
-            return os.linesep.join(row[0] for row in logs)
-
-    full_log = str(row[0])
+                full_log = os.linesep.join(row[0] for row in logs)
+
+            except SnowparkSQLException as spcs_error:
+                if spcs_error.sql_error_code == 2143:
+                    logger.debug("persistent logs may not be enabled, falling back to event table")
+                else:
+                    # If SPCS Interface fails for any other reason,
+                    # for example, an incorrect argument format, raise the error directly
+                    raise
+                # event table accepts job name, not fully qualified name
+                db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+                db = db or session.get_current_database()
+                schema = schema or session.get_current_schema()
+                event_table_logs = _get_service_log_from_event_table(
+                    session,
+                    name,
+                    database=db,
+                    schema=schema,
+                    instance_id=instance_id if instance_id else 0,
+                    limit=limit,
+                )
+                if len(event_table_logs) == 0:
+                    raise RuntimeError(
+                        "No logs were found. Please verify that the database, schema, and job ID are correct."
+                    )
+                full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)
 
     # If verbose is True, return the complete log
     if verbose:
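
Note: the except-branch above establishes a three-tier fallback: `SYSTEM$GET_SERVICE_LOGS` while the service object exists, then the `spcs_get_logs` table function (persistent logs, no event-table grant required), then the telemetry event table. A condensed sketch of tiers two and three using the helpers defined later in this file (error code 2143 and the container-name constant are taken from the diff; the wrapper itself is illustrative):

```python
import json
import os

from snowflake.snowpark.exceptions import SnowparkSQLException


def _logs_with_fallback(session, job_name: str, instance_id: int, limit: int = -1) -> str:
    """Illustrative mirror of the fallback chain added to _get_logs above."""
    try:
        # Tier 2: SPCS persistent logs via <job>!spcs_get_logs(...)
        rows = _get_logs_spcs(
            session, job_name, limit=limit, instance_id=instance_id,
            container_name=constants.DEFAULT_CONTAINER_NAME,
        )
        return os.linesep.join(row[0] for row in rows)
    except SnowparkSQLException as e:
        if e.sql_error_code != 2143:  # 2143 ~ persistent logs not enabled
            raise
    # Tier 3: event table (the real code passes the bare job name plus the
    # current database/schema, not a fully qualified name)
    rows = _get_service_log_from_event_table(session, job_name, instance_id=instance_id, limit=limit)
    return os.linesep.join(json.loads(row[0]) for row in rows)
```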
@@ -340,52 +397,70 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         RuntimeError: If the instances died or if some instances disappeared.
     """
 
-    target_instances = _get_target_instances(session, job_id)
+    try:
+        target_instances = _get_target_instances(session, job_id)
+    except errors.ProgrammingError:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve service information")
 
     if target_instances == 1:
         return 0
 
     try:
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-    except SnowparkSQLException:
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True
+        )
+    except errors.ProgrammingError:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")
 
-    if not rows:
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
        return None
 
-    if target_instances > len(rows):
+    if target_instances > len(rows["data"]):
        raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
+    requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
     # Sort by start_time first, then by instance_id
     try:
-        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+        sorted_instances = sorted(
+            rows["data"],
+            key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
+        )
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
-    if not head_instance["start_time"]:
+    if not head_instance[requested_attributes["start_time"]]:
         # If head instance hasn't started yet, return None
         return None
     try:
-        return int(head_instance["instance_id"])
+        return int(head_instance[requested_attributes["instance_id"]])
     except (ValueError, TypeError):
         return 0
 
 
 def _get_service_log_from_event_table(
-    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
-) -> list[Row]:
+    session: snowpark.Session,
+    name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    instance_id: Optional[int] = None,
+    limit: int = -1,
+) -> Any:
     params: list[Any] = [
-        database,
-        schema,
         name,
     ]
     query = [
         "SELECT VALUE FROM snowflake.telemetry.events_view",
-        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
     ]
+    if database:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
+        params.append(database)
+
+    if schema:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
+        params.append(schema)
 
     if instance_id:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
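
Note: with the reworked signature the database and schema predicates are appended only when supplied, and because the guard is `if instance_id:` rather than `if instance_id is not None:`, instance 0 (the head instance) adds no instance filter. For an illustrative call (unchanged lines between these hunks may append further predicates before LIMIT):

```python
rows = _get_service_log_from_event_table(
    session, "MY_JOB", database="MY_DB", schema="MY_SCHEMA", instance_id=0, limit=100
)
# Assembled query (qmark parameters shown inline as comments):
#   SELECT VALUE FROM snowflake.telemetry.events_view
#   WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?     -- 'MY_JOB'
#   AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?      -- 'MY_DB'
#   AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?        -- 'MY_SCHEMA'
#   LIMIT ?                                               -- 100
```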
@@ -398,20 +473,23 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-    rows = session.sql(
-        "\n".join(line for line in query if line),
-        params=params,
-    ).collect()
-    return rows
+    rows = session._conn.run_query(
+        "\n".join(line for line in query if line), params=params, _force_qmark_paramstyle=True
+    )
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
+        return []
+    return rows["data"]
 
 
-def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
-    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-    return row
+def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
+    row = session._conn.run_query("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True)
+    # pass the type check
+    if not row or not isinstance(row, dict) or not row.get("data"):
+        raise errors.ProgrammingError("failed to retrieve service information")
+    return row["data"][0]
 
 
-def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
     """
     Check if the compute pool has enough available instances.
 
@@ -420,14 +498,19 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
         compute_pool (str): The name of the compute pool.
 
     Returns:
-        Row: The compute pool information.
+        Any: The compute pool information.
 
     Raises:
         ValueError: If the compute pool is not found.
     """
     try:
-        (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
-        return pool_info
+        compute_pool_info = session._conn.run_query(
+            "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,), _force_qmark_paramstyle=True
+        )
+        # pass the type check
+        if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
+            raise ValueError(f"Compute pool '{compute_pool}' not found")
+        return compute_pool_info["data"][0]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
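
Note: several helpers now call the connector-internal `session._conn.run_query(...)` instead of `session.sql(...).collect()`. Judging from the guards repeated throughout this diff, that call returns a dict whose `"data"` key holds positional row tuples; this shape is inferred from the code, not documented connector behavior:

```python
# Result shape implied by the checks above (internal API, inferred):
result = session._conn.run_query(
    "DESCRIBE SERVICE IDENTIFIER(?)", params=("MY_JOB",), _force_qmark_paramstyle=True
)
if isinstance(result, dict) and result.get("data"):
    first_row = result["data"][0]  # positional tuple, indexed via get_attribute_map
```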
@@ -437,4 +520,42 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-    return int(row["target_instances"])
+    requested_attributes = query_helper.get_attribute_map(session, {"target_instances": 9})
+    return int(row[requested_attributes["target_instances"]])
+
+
+def _get_logs_spcs(
+    session: snowpark.Session,
+    fully_qualified_name: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: Optional[str] = None,
+    start_time: Optional[str] = None,
+    end_time: Optional[str] = None,
+) -> list[Row]:
+    query = [
+        f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
+    ]
+    conditions_params = []
+    if start_time:
+        conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
+    if end_time:
+        conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
+    if len(conditions_params) > 0:
+        query.append(", ".join(conditions_params))
+
+    query.append("))")
+
+    query_params = []
+    if instance_id is not None:
+        query_params.append(f"INSTANCE_ID = {instance_id}")
+    if container_name:
+        query_params.append(f"CONTAINER_NAME = '{container_name}'")
+    if len(query_params) > 0:
+        query.append("WHERE " + " AND ".join(query_params))
+
+    query.append("ORDER BY TIMESTAMP ASC")
+    if limit > 0:
+        query.append(f" LIMIT {limit};")
+    rows = session.sql("\n".join(query)).collect()
+    return rows
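
Note: for reference, a call such as the one below (illustrative names) assembles the following statement. Unlike the qmark-bound queries elsewhere in this diff, the arguments here are interpolated directly into the SQL string:

```python
rows = _get_logs_spcs(session, "DB.S.MY_JOB", limit=10, instance_id=0, container_name="main")
# "\n".join(query) yields:
#   SELECT LOG FROM table(DB.S.MY_JOB!spcs_get_logs(
#   ))
#   WHERE INSTANCE_ID = 0 AND CONTAINER_NAME = 'main'
#   ORDER BY TIMESTAMP ASC
#    LIMIT 10;   <- leading space comes from the f" LIMIT {limit};" literal
```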