snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (49)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/_internal/utils/identifier.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +61 -0
  5. snowflake/ml/jobs/__init__.py +2 -0
  6. snowflake/ml/jobs/_utils/constants.py +3 -2
  7. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  8. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  9. snowflake/ml/jobs/_utils/payload_utils.py +89 -40
  10. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  11. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  13. snowflake/ml/jobs/_utils/spec_utils.py +29 -5
  14. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  15. snowflake/ml/jobs/_utils/types.py +5 -1
  16. snowflake/ml/jobs/decorators.py +20 -28
  17. snowflake/ml/jobs/job.py +197 -61
  18. snowflake/ml/jobs/manager.py +253 -121
  19. snowflake/ml/model/_client/model/model_impl.py +58 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  21. snowflake/ml/model/_client/ops/model_ops.py +18 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +23 -6
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  24. snowflake/ml/model/_client/sql/service.py +68 -20
  25. snowflake/ml/model/_client/sql/stage.py +5 -2
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  27. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  28. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  31. snowflake/ml/model/_signatures/core.py +24 -0
  32. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  33. snowflake/ml/model/target_platform.py +11 -0
  34. snowflake/ml/model/task.py +9 -0
  35. snowflake/ml/model/type_hints.py +5 -13
  36. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  37. snowflake/ml/monitoring/explain_visualize.py +2 -2
  38. snowflake/ml/monitoring/model_monitor.py +0 -4
  39. snowflake/ml/registry/_manager/model_manager.py +30 -15
  40. snowflake/ml/registry/registry.py +144 -47
  41. snowflake/ml/utils/connection_params.py +1 -1
  42. snowflake/ml/utils/html_utils.py +263 -0
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
  45. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
  46. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  47. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  48. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -7,21 +8,23 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 import yaml
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
-from snowflake.ml.jobs._utils import constants, interop_utils, types
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
-TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
 
 T = TypeVar("T")
 
 logger = logging.getLogger(__name__)
 
 
-class MLJob(Generic[T]):
+class MLJob(Generic[T], SerializableSessionMixin):
     def __init__(
         self,
         id: str,
@@ -67,7 +70,8 @@ class MLJob(Generic[T]):
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
         row = _get_service_info(self._session, self.id)
-        return cast(str, row["compute_pool"])
+        compute_pool = row[query_helper.get_attribute_map(self._session, {"compute_pool": 5})["compute_pool"]]
+        return cast(str, compute_pool)
 
     @property
     def _service_spec(self) -> dict[str, Any]:
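`query_helper.get_attribute_map` is introduced in this release (`snowflake/ml/jobs/_utils/query_helper.py`, +9 lines, not shown here). Judging from its call sites, it resolves result-set column names to positional indices, with the caller passing fallback positions for when metadata is unavailable. A minimal sketch of what such a helper might look like, assuming cursor metadata is reachable via `session._conn._cursor.description` (an assumption, not confirmed by this diff):

```python
# Hypothetical sketch only; the real query_helper.py body is not shown in this diff.
from typing import Any


def get_attribute_map(session: Any, desired_attributes: dict[str, int]) -> dict[str, int]:
    """Map column names to their positions in the last query's result set."""
    attribute_map = dict(desired_attributes)  # caller-supplied fallback positions
    # Assumption: the connector cursor of the last query exposes column metadata.
    for index, column in enumerate(session._conn._cursor.description):
        if column.name.lower() in attribute_map:
            attribute_map[column.name.lower()] = index
    return attribute_map
```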
@@ -181,16 +185,20 @@ class MLJob(Generic[T]):
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
             if status == "PENDING" and not warning_shown:
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
-                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                requested_attributes = {"max_nodes": 3, "active_nodes": 9}
+                if (
+                    pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
+                ) < self.min_instances:
                     logger.warning(
-                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use)."
-                        " Job execution may be delayed."
+                        f'Compute pool busy ({pool_info[requested_attributes["active_nodes"]]}'
+                        f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
+                        f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
                 raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
-            delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
+            delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
     @snowpark._internal.utils.private_preview(version="1.8.2")
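The polling multiplier drops from 2 to 1.2 in this hunk, so the delay between status checks now grows gently instead of doubling. A quick illustration; the 1 s starting delay and 30 s cap are assumptions (`constants.JOB_POLL_MAX_DELAY_SECONDS` is not shown in this diff):

```python
# Poll-delay growth under the new 1.2x factor (starting delay and cap assumed).
delay, delays = 1.0, []
for _ in range(10):
    delays.append(round(delay, 2))
    delay = min(delay * 1.2, 30.0)
print(delays)  # [1.0, 1.2, 1.44, 1.73, 2.07, 2.49, 2.99, 3.58, 4.3, 5.16]
```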
@@ -220,27 +228,46 @@ class MLJob(Generic[T]):
             return cast(T, self._result.result)
         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def cancel(self) -> None:
+        """
+        Cancel the running job.
+        Raises:
+            RuntimeError: If cancellation fails. # noqa: DAR401
+        """
+        try:
+            self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
+            logger.debug(f"Cancellation requested for job {self.id}")
+        except SnowparkSQLException as e:
+            raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
     if instance_id is not None:
         # Get specific instance status
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-        for row in rows:
-            if row["instance_id"] == str(instance_id):
-                return cast(types.JOB_STATUS, row["status"])
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job_id], _force_qmark_paramstyle=True
+        )
+        request_attributes = query_helper.get_attribute_map(session, {"status": 5, "instance_id": 4})
+        if isinstance(rows, dict) and "data" in rows:
+            for row in rows["data"]:
+                if row[request_attributes["instance_id"]] == str(instance_id):
+                    return cast(types.JOB_STATUS, row[request_attributes["status"]])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
         row = _get_service_info(session, job_id)
-        return cast(types.JOB_STATUS, row["status"])
+        request_attributes = query_helper.get_attribute_map(session, {"status": 1})
+        return cast(types.JOB_STATUS, row[request_attributes["status"]])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
     row = _get_service_info(session, job_id)
-    return cast(dict[str, Any], yaml.safe_load(row["spec"]))
+    requested_attributes = query_helper.get_attribute_map(session, {"spec": 6})
+    return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
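Typical use of the new `cancel()` method might look like the sketch below; the job identifier is a placeholder, and the `jobs.get_job` lookup is assumed from the existing public API:

```python
# Illustrative usage; assumes a job already submitted under this (placeholder) name.
from snowflake.ml import jobs

job = jobs.get_job("MY_DB.MY_SCHEMA.MY_JOB")
try:
    job.cancel()  # wraps CALL <job_id>!spcs_cancel_job()
except RuntimeError as e:
    print(f"Cancellation failed: {e}")
```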
@@ -262,6 +289,7 @@ def _get_logs(
 
     Raises:
         RuntimeError: if failed to get head instance_id
+        SnowparkSQLException: if there is an error retrieving logs from the SPCS interface.
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
@@ -279,30 +307,59 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-        (row,) = session.sql(
+        data = session._conn.run_query(
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-        ).collect()
-    except SnowparkSQLException as e:
-        if "Container Status: PENDING" in e.message:
+            _force_qmark_paramstyle=True,
+        )
+        if isinstance(data, dict) and "data" in data:
+            full_log = str(data["data"][0][0])
+        # pass type check
+        else:
+            full_log = ""
+    except errors.ProgrammingError as e:
+        if "Container Status: PENDING" in str(e):
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
-            # event table accepts job name, not fully qualified name
-            # cast is to resolve the type check error
-            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
-            db = cast(str, db or session.get_current_database())
-            schema = cast(str, schema or session.get_current_schema())
-            logs = _get_service_log_from_event_table(
-                session, db, schema, name, limit, instance_id if instance_id else None
-            )
-            if len(logs) == 0:
-                raise RuntimeError(
-                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+            # Fallback plan:
+            # 1. Try the SPCS interface (doesn't require event table permission)
+            # 2. If the interface call fails, query the event table (requires permission)
+            logger.debug("falling back to SPCS interface for logs")
+            try:
+                logs = _get_logs_spcs(
+                    session,
+                    job_id,
+                    limit=limit,
+                    instance_id=instance_id if instance_id else 0,
+                    container_name=constants.DEFAULT_CONTAINER_NAME,
                 )
-            return os.linesep.join(row[0] for row in logs)
-
-    full_log = str(row[0])
+                full_log = os.linesep.join(row[0] for row in logs)
+
+            except SnowparkSQLException as spcs_error:
+                if spcs_error.sql_error_code == 2143:
+                    logger.debug("persistent logs may not be enabled, falling back to event table")
+                else:
+                    # If the SPCS interface fails for any other reason,
+                    # for example an incorrect argument format, raise the error directly
+                    raise
+                # event table accepts job name, not fully qualified name
+                db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+                db = db or session.get_current_database()
+                schema = schema or session.get_current_schema()
+                event_table_logs = _get_service_log_from_event_table(
+                    session,
+                    name,
+                    database=db,
+                    schema=schema,
+                    instance_id=instance_id if instance_id else 0,
+                    limit=limit,
+                )
+                if len(event_table_logs) == 0:
+                    raise RuntimeError(
+                        "No logs were found. Please verify that the database, schema, and job ID are correct."
+                    )
+                full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)
 
     # If verbose is True, return the complete log
     if verbose:
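Log retrieval now degrades in three tiers: `SYSTEM$GET_SERVICE_LOGS` first; on failure, the `<job>!spcs_get_logs` table function (no event-table grant required); and only when that reports SQL error 2143 (persistent logs unavailable), the telemetry event table. The control flow above, condensed with hypothetical helper names:

```python
# Hypothetical condensation of the fallback order implemented inline above;
# the three helper names are invented for illustration.
def _resolve_logs(session, job_id: str, limit: int = -1) -> str:
    try:
        return system_get_service_logs(session, job_id, limit)  # tier 1
    except errors.ProgrammingError:
        try:
            return spcs_interface_logs(session, job_id, limit)  # tier 2, no event-table grant
        except SnowparkSQLException as e:
            if e.sql_error_code != 2143:  # anything but "persistent logs not enabled"
                raise
            return event_table_logs(session, job_id, limit)  # tier 3, needs permission
```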
@@ -338,47 +395,72 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
 
     Raises:
         RuntimeError: If the instances died or if some instances disappeared.
-
     """
+
     try:
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-    except SnowparkSQLException:
+        target_instances = _get_target_instances(session, job_id)
+    except errors.ProgrammingError:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve service information")
+
+    if target_instances == 1:
+        return 0
+
+    try:
+        rows = session._conn.run_query(
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True
+        )
+    except errors.ProgrammingError:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")
-    if not rows:
+
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         return None
-    if _get_target_instances(session, job_id) > len(rows):
+
+    if target_instances > len(rows["data"]):
         raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
+    requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
     # Sort by start_time first, then by instance_id
     try:
-        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+        sorted_instances = sorted(
+            rows["data"],
+            key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
+        )
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
-    if not head_instance["start_time"]:
+    if not head_instance[requested_attributes["start_time"]]:
         # If head instance hasn't started yet, return None
         return None
     try:
-        return int(head_instance["instance_id"])
+        return int(head_instance[requested_attributes["instance_id"]])
     except (ValueError, TypeError):
         return 0
 
 
 def _get_service_log_from_event_table(
-    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
-) -> list[Row]:
+    session: snowpark.Session,
+    name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    instance_id: Optional[int] = None,
+    limit: int = -1,
+) -> Any:
     params: list[Any] = [
-        database,
-        schema,
         name,
     ]
     query = [
         "SELECT VALUE FROM snowflake.telemetry.events_view",
-        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
     ]
+    if database:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
+        params.append(database)
+
+    if schema:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
+        params.append(schema)
 
     if instance_id:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
@@ -391,20 +473,23 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-    rows = session.sql(
-        "\n".join(line for line in query if line),
-        params=params,
-    ).collect()
-    return rows
+    rows = session._conn.run_query(
+        "\n".join(line for line in query if line), params=params, _force_qmark_paramstyle=True
+    )
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
+        return []
+    return rows["data"]
 
 
-def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
-    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-    return row
+def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
+    row = session._conn.run_query("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True)
+    # pass the type check
+    if not row or not isinstance(row, dict) or not row.get("data"):
+        raise errors.ProgrammingError("failed to retrieve service information")
+    return row["data"][0]
 
 
-def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
     """
     Check if the compute pool has enough available instances.
 
@@ -413,13 +498,64 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
         compute_pool (str): The name of the compute pool.
 
     Returns:
-        Row: The compute pool information.
+        Any: The compute pool information.
+
+    Raises:
+        ValueError: If the compute pool is not found.
     """
-    (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
-    return pool_info
+    try:
+        compute_pool_info = session._conn.run_query(
+            "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,), _force_qmark_paramstyle=True
+        )
+        # pass the type check
+        if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
+            raise ValueError(f"Compute pool '{compute_pool}' not found")
+        return compute_pool_info["data"][0]
+    except ValueError as e:
+        if "not enough values to unpack" in str(e):
+            raise ValueError(f"Compute pool '{compute_pool}' not found")
+        raise
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-    return int(row["target_instances"]) if row["target_instances"] else 0
+    requested_attributes = query_helper.get_attribute_map(session, {"target_instances": 9})
+    return int(row[requested_attributes["target_instances"]])
+
+
+def _get_logs_spcs(
+    session: snowpark.Session,
+    fully_qualified_name: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: Optional[str] = None,
+    start_time: Optional[str] = None,
+    end_time: Optional[str] = None,
+) -> list[Row]:
+    query = [
+        f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
+    ]
+    conditions_params = []
+    if start_time:
+        conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
+    if end_time:
+        conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
+    if len(conditions_params) > 0:
+        query.append(", ".join(conditions_params))
+
+    query.append("))")
+
+    query_params = []
+    if instance_id is not None:
+        query_params.append(f"INSTANCE_ID = {instance_id}")
+    if container_name:
+        query_params.append(f"CONTAINER_NAME = '{container_name}'")
+    if len(query_params) > 0:
+        query.append("WHERE " + " AND ".join(query_params))
+
+    query.append("ORDER BY TIMESTAMP ASC")
+    if limit > 0:
+        query.append(f" LIMIT {limit};")
+    rows = session.sql("\n".join(query)).collect()
+    return rows
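For reference, a call like the one `_get_logs` now makes would assemble the following SQL (identifiers are placeholders; the stray space before `LIMIT` is reproduced from the string literal above):

```python
# Illustrative trace of the query _get_logs_spcs builds; the job name and
# container name are placeholders, not values taken from this diff.
rows = _get_logs_spcs(
    session,
    "MY_DB.MY_SCHEMA.MY_JOB",
    limit=100,
    instance_id=0,
    container_name="main-container",
)
# Assembled SQL (list entries joined with "\n"):
#   SELECT LOG FROM table(MY_DB.MY_SCHEMA.MY_JOB!spcs_get_logs(
#   ))
#   WHERE INSTANCE_ID = 0 AND CONTAINER_NAME = 'main-container'
#   ORDER BY TIMESTAMP ASC
#    LIMIT 100;
```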