snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py CHANGED
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,22 +59,17 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = payload_utils.create_function_payload(func, *args, **kwargs)
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                env_vars=env_vars,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
                 session=payload.session or session,
+                **kwargs,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
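
The net effect: every option except compute_pool, stage_name, target_instances, pip_requirements, external_access_integrations, and session now travels through **kwargs. A minimal caller sketch, assuming remote is re-exported from snowflake.ml.jobs as in earlier releases and an active Snowpark session; the pool and stage names are placeholders:

from snowflake.ml.jobs import remote

@remote(
    "MY_COMPUTE_POOL",                # compute pool, still positional
    stage_name="payload_stage",       # required keyword
    target_instances=2,               # promoted to a named parameter
    min_instances=1,                  # now passed through **kwargs
    env_vars={"LOG_LEVEL": "DEBUG"},  # also **kwargs
)
def train(n_estimators: int) -> float:
    # placeholder body; executes remotely on the compute pool
    return float(n_estimators)

job = train(100)  # dispatches a job and returns an MLJob handle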
snowflake/ml/jobs/job.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -9,7 +10,8 @@ import yaml
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
-from snowflake.ml.jobs._utils import constants, interop_utils, types
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
@@ -21,7 +23,7 @@ T = TypeVar("T")
 logger = logging.getLogger(__name__)
 
 
-class MLJob(Generic[T]):
+class MLJob(Generic[T], SerializableSessionMixin):
     def __init__(
         self,
         id: str,
@@ -220,6 +222,19 @@ class MLJob(Generic[T]):
             return cast(T, self._result.result)
         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def cancel(self) -> None:
+        """
+        Cancel the running job.
+        Raises:
+            RuntimeError: If cancellation fails. # noqa: DAR401
+        """
+        try:
+            self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
+            logger.debug(f"Cancellation requested for job {self.id}")
+        except SnowparkSQLException as e:
+            raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
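
A short usage sketch for the new method; the job handle can come from any submit path, such as the decorated function in the earlier example:

job = train(100)  # MLJob handle from the @remote sketch above
try:
    job.cancel()  # issues CALL <job_id>!spcs_cancel_job()
except RuntimeError as e:
    print(f"Cancellation failed: {e}")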
@@ -262,6 +277,7 @@ def _get_logs(
 
     Raises:
         RuntimeError: if failed to get head instance_id
+        SnowparkSQLException: if there is an error retrieving logs from SPCS interface.
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
@@ -279,30 +295,55 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-        (row,) = session.sql(
+        (row,) = query_helper.run_query(
+            session,
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-        ).collect()
+        )
+        full_log = str(row[0])
     except SnowparkSQLException as e:
         if "Container Status: PENDING" in e.message:
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
-            # event table accepts job name, not fully qualified name
-            # cast is to resolve the type check error
-            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
-            db = cast(str, db or session.get_current_database())
-            schema = cast(str, schema or session.get_current_schema())
-            logs = _get_service_log_from_event_table(
-                session, db, schema, name, limit, instance_id if instance_id else None
-            )
-            if len(logs) == 0:
-                raise RuntimeError(
-                    "No logs were found. Please verify that the database, schema, and job ID are correct."
-                )
-            return os.linesep.join(row[0] for row in logs)
-
-    full_log = str(row[0])
+            # Fallback plan:
+            # 1. Try the SPCS interface (doesn't require event table permission)
+            # 2. If the interface call fails, query the event table (requires permission)
+            logger.debug("falling back to SPCS interface for logs")
+            try:
+                logs = _get_logs_spcs(
+                    session,
+                    job_id,
+                    limit=limit,
+                    instance_id=instance_id if instance_id else 0,
+                    container_name=constants.DEFAULT_CONTAINER_NAME,
+                )
+                full_log = os.linesep.join(row[0] for row in logs)
+
+            except SnowparkSQLException as spcs_error:
+                if spcs_error.sql_error_code == 2143:
+                    logger.debug("persistent logs may not be enabled, falling back to event table")
+                else:
+                    # If the SPCS interface fails for any other reason
+                    # (for example, an incorrect argument format), raise the error directly
+                    raise
+                # event table accepts job name, not fully qualified name
+                db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+                db = db or session.get_current_database()
+                schema = schema or session.get_current_schema()
+                event_table_logs = _get_service_log_from_event_table(
+                    session,
+                    name,
+                    database=db,
+                    schema=schema,
+                    instance_id=instance_id if instance_id else 0,
+                    limit=limit,
+                )
+                if len(event_table_logs) == 0:
+                    raise RuntimeError(
+                        "No logs were found. Please verify that the database, schema, and job ID are correct."
+                    )
+                full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)
 
     # If verbose is True, return the complete log
     if verbose:
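
The retrieval order is now SYSTEM$GET_SERVICE_LOGS, then the SPCS spcs_get_logs table function, then the event table. A generic sketch of this first-success fallback pattern in isolation (an illustrative helper, not part of the package):

from typing import Callable

def first_available(*sources: Callable[[], str]) -> str:
    """Return output from the first log source that succeeds.

    Mirrors the order used above: system function -> SPCS interface -> event table.
    """
    failures: list[Exception] = []
    for source in sources:
        try:
            return source()
        except Exception as exc:  # the real code narrows this to SnowparkSQLException
            failures.append(exc)
    raise RuntimeError(f"No logs were found; {len(failures)} sources failed.")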
@@ -340,13 +381,21 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         RuntimeError: If the instances died or if some instances disappeared.
     """
 
-    target_instances = _get_target_instances(session, job_id)
+    try:
+        target_instances = _get_target_instances(session, job_id)
+    except SnowparkSQLException:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve service information")
 
     if target_instances == 1:
         return 0
 
     try:
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        rows = query_helper.run_query(
+            session,
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)",
+            params=(job_id,),
+        )
     except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")
@@ -373,19 +422,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
 
 
 def _get_service_log_from_event_table(
-    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
+    session: snowpark.Session,
+    name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    instance_id: Optional[int] = None,
+    limit: int = -1,
 ) -> list[Row]:
+    event_table_name = session.sql("SHOW PARAMETERS LIKE 'event_table' IN ACCOUNT").collect()[0]["value"]
+    query = [
+        "SELECT VALUE FROM IDENTIFIER(?)",
+        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+    ]
     params: list[Any] = [
-        database,
-        schema,
+        event_table_name,
         name,
     ]
-    query = [
-        "SELECT VALUE FROM snowflake.telemetry.events_view",
-        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
-        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
-    ]
+    if database:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
+        params.append(database)
+
+    if schema:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
+        params.append(schema)
 
     if instance_id:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
@@ -398,16 +457,18 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-    rows = session.sql(
+    # the wrapper used by query_helper has no declared return type, so a
+    # `# type: ignore[no-any-return]` is needed to satisfy the type checker
+    rows = query_helper.run_query(
+        session,
         "\n".join(line for line in query if line),
         params=params,
-    ).collect()
-    return rows
+    )
+    return rows  # type: ignore[no-any-return]
 
 
-def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
-    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
+    (row,) = query_helper.run_query(session, "DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,))
     return row
 
 
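For reference, the _get_service_log_from_event_table changes above bind the account's configured event table via IDENTIFIER(?) instead of hard-coding snowflake.telemetry.events_view. A sketch of the statement assembled when database, schema, and instance_id are all provided (clauses not visible in these hunks are omitted):

# Illustrative only: reproduces the string built by the query/params lists above.
sketch = "\n".join(
    [
        "SELECT VALUE FROM IDENTIFIER(?)",
        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
        'AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?',
        "LIMIT ?",
    ]
)
print(sketch)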
@@ -426,8 +487,10 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
         ValueError: If the compute pool is not found.
     """
     try:
-        (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
-        return pool_info
+        # the wrapper used by query_helper has no declared return type, so a
+        # `# type: ignore[no-any-return]` is needed to satisfy the type checker
+        (pool_info,) = query_helper.run_query(session, "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,))
+        return pool_info  # type: ignore[no-any-return]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
@@ -438,3 +501,40 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
     return int(row["target_instances"])
+
+
+def _get_logs_spcs(
+    session: snowpark.Session,
+    fully_qualified_name: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: Optional[str] = None,
+    start_time: Optional[str] = None,
+    end_time: Optional[str] = None,
+) -> list[Row]:
+    query = [
+        f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
+    ]
+    conditions_params = []
+    if start_time:
+        conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
+    if end_time:
+        conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
+    if len(conditions_params) > 0:
+        query.append(", ".join(conditions_params))
+
+    query.append("))")
+
+    query_params = []
+    if instance_id is not None:
+        query_params.append(f"INSTANCE_ID = {instance_id}")
+    if container_name:
+        query_params.append(f"CONTAINER_NAME = '{container_name}'")
+    if len(query_params) > 0:
+        query.append("WHERE " + " AND ".join(query_params))
+
+    query.append("ORDER BY TIMESTAMP ASC")
+    if limit > 0:
+        query.append(f" LIMIT {limit};")
+    rows = session.sql("\n".join(query)).collect()
+    return rows
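
A hypothetical direct call to the new helper, fetching up to 100 log lines from instance 0; the qualified name and container name are placeholders (the production call path in _get_logs passes constants.DEFAULT_CONTAINER_NAME), and an active Snowpark session is assumed:

rows = _get_logs_spcs(
    session,
    "MY_DB.MY_SCHEMA.MY_JOB",
    limit=100,
    instance_id=0,
    container_name="main-container",  # placeholder
)
print(os.linesep.join(row[0] for row in rows))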