snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (52)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/jobs/__init__.py +2 -0
  3. snowflake/ml/jobs/_utils/constants.py +2 -0
  4. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  5. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  6. snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
  7. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  8. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  9. snowflake/ml/jobs/_utils/types.py +22 -2
  10. snowflake/ml/jobs/job_definition.py +232 -0
  11. snowflake/ml/jobs/manager.py +16 -177
  12. snowflake/ml/lineage/lineage_node.py +1 -1
  13. snowflake/ml/model/__init__.py +6 -0
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  15. snowflake/ml/model/_client/model/model_version_impl.py +109 -32
  16. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +45 -2
  18. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  19. snowflake/ml/model/_client/ops/service_ops.py +81 -61
  20. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
  23. snowflake/ml/model/_client/sql/model_version.py +30 -6
  24. snowflake/ml/model/_client/sql/service.py +30 -29
  25. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
  31. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
  33. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  36. snowflake/ml/model/_packager/model_packager.py +1 -1
  37. snowflake/ml/model/_signatures/core.py +85 -0
  38. snowflake/ml/model/_signatures/utils.py +55 -0
  39. snowflake/ml/model/code_path.py +104 -0
  40. snowflake/ml/model/custom_model.py +55 -13
  41. snowflake/ml/model/model_signature.py +13 -1
  42. snowflake/ml/model/openai_signatures.py +97 -0
  43. snowflake/ml/model/type_hints.py +2 -0
  44. snowflake/ml/registry/_manager/model_manager.py +230 -15
  45. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  46. snowflake/ml/registry/registry.py +4 -4
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
  49. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
  50. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/url.py
@@ -0,0 +1,42 @@
+from urllib.parse import urlunparse
+
+from snowflake import snowpark as snowpark
+
+JOB_URL_PREFIX = "#/compute/job/"
+SERVICE_URL_PREFIX = "#/compute/service/"
+
+
+def get_snowflake_url(
+    session: snowpark.Session,
+    url_path: str,
+    params: str = "",
+    query: str = "",
+    fragment: str = "",
+) -> str:
+    """Construct a Snowflake URL from session connection details and URL components.
+
+    Args:
+        session: The Snowpark session containing connection details.
+        url_path: The path component of the URL (e.g., "/compute/job/123").
+        params: Optional parameters for the URL (RFC 1808). Defaults to "".
+        query: Optional query string for the URL. Defaults to "".
+        fragment: Optional fragment identifier for the URL (e.g., "#section"). Defaults to "".
+
+    Returns:
+        A fully constructed Snowflake URL string with scheme, host, and specified components.
+    """
+    scheme = "https"
+    if hasattr(session.connection, "scheme"):
+        scheme = session.connection.scheme
+    host = session.connection.host
+
+    return urlunparse(
+        (
+            scheme,
+            host,
+            url_path,
+            params,
+            query,
+            fragment,
+        )
+    )
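
For orientation, a minimal sketch of how this new helper can be called, assuming an active Snowpark session named `session` (the path below mirrors the docstring's own example; the exact output depends on the session's host):

    from snowflake.ml._internal.utils import url

    # Hypothetical usage, for illustration only.
    job_url = url.get_snowflake_url(session, "/compute/job/123")
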
snowflake/ml/jobs/__init__.py
@@ -2,6 +2,7 @@ from snowflake.ml.jobs._interop.exception_utils import install_exception_display
 from snowflake.ml.jobs._utils.types import JOB_STATUS
 from snowflake.ml.jobs.decorators import remote
 from snowflake.ml.jobs.job import MLJob
+from snowflake.ml.jobs.job_definition import MLJobDefinition
 from snowflake.ml.jobs.manager import (
     delete_job,
     get_job,
@@ -24,4 +25,5 @@ __all__ = [
     "MLJob",
     "JOB_STATUS",
     "submit_from_stage",
+    "MLJobDefinition",
 ]
snowflake/ml/jobs/_utils/constants.py
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
+DEFAULT_PYTHON_VERSION = "3.10"

 # Environment variables
 STAGE_MOUNT_PATH_ENV_VAR = "MLRS_STAGE_MOUNT_PATH"
@@ -30,6 +31,7 @@ DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
 DEFAULT_IMAGE_TAG = "1.8.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
+DEFAULT_PYTHON_VERSION = "3.10"

 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3
snowflake/ml/jobs/_utils/payload_utils.py
@@ -11,6 +11,7 @@ from importlib.abc import Traversable
 from pathlib import Path, PurePath
 from types import ModuleType
 from typing import IO, Any, Callable, Optional, Union, cast, get_args, get_origin
+from uuid import uuid4

 import cloudpickle as cp
 from packaging import version
@@ -36,10 +37,15 @@ _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
 _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
+JOB_ID_PREFIX = "MLJOB_"


 def _compress_and_upload_file(
-    session: snowpark.Session, source_path: Path, stage_path: PurePath, import_path: Optional[str] = None
+    session: snowpark.Session,
+    source_path: Path,
+    stage_path: PurePath,
+    import_path: Optional[str] = None,
+    overwrite: bool = True,
 ) -> None:
     absolute_source_path = source_path.absolute()
     leading_path = absolute_source_path.as_posix()[: -len(import_path)] if import_path else None
@@ -49,11 +55,13 @@ def _compress_and_upload_file(
         cast(IO[bytes], stream),
         stage_path.joinpath(filename).as_posix(),
         auto_compress=False,
-        overwrite=True,
+        overwrite=overwrite,
     )


-def _upload_directory(session: snowpark.Session, source_path: Path, payload_stage_path: PurePath) -> None:
+def _upload_directory(
+    session: snowpark.Session, source_path: Path, payload_stage_path: PurePath, overwrite: bool = True
+) -> None:
     # Manually traverse the directory and upload each file, since Snowflake PUT
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
@@ -81,12 +89,14 @@ def _upload_directory(session: snowpark.Session, source_path: Path, payload_stag
         session.file.put(
             str(path),
             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
-            overwrite=True,
+            overwrite=overwrite,
             auto_compress=False,
         )


-def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+def upload_payloads(
+    session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec, overwrite: bool = True
+) -> None:
     for spec in payload_specs:
         source_path = spec.source_path
         remote_relative_path = spec.remote_relative_path
@@ -109,6 +119,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
                     source_path,
                     stage_path,
                     remote_relative_path.as_posix() if remote_relative_path else None,
+                    overwrite=overwrite,
                 )
             else:
                 _upload_directory(session, source_path, payload_stage_path)
@@ -120,12 +131,13 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
                     source_path,
                     stage_path,
                     remote_relative_path.as_posix() if remote_relative_path else None,
+                    overwrite=overwrite,
                 )
             else:
                 session.file.put(
                     str(source_path.resolve()),
                     payload_stage_path.as_posix(),
-                    overwrite=True,
+                    overwrite=overwrite,
                     auto_compress=False,
                 )

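
A sketch of how the new `overwrite` flag flows through these internal helpers; the stage name and file below are hypothetical, and the `PayloadSpec(source_path, remote_relative_path)` shape is inferred from this diff:

    from pathlib import Path, PurePath
    from snowflake.ml.jobs._utils import payload_utils, types

    # Re-upload a payload without clobbering files already on the stage.
    payload_utils.upload_payloads(
        session,                                    # active snowpark.Session (assumed)
        PurePath("@my_stage/app"),                  # hypothetical stage path
        types.PayloadSpec(Path("train.py"), None),  # (source_path, remote_relative_path)
        overwrite=False,                            # threaded down to each PUT call
    )
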
@@ -455,7 +467,9 @@ class JobPayload:
         self.pip_requirements = pip_requirements
         self.imports = imports

-    def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
+    def upload(
+        self, session: snowpark.Session, stage_path: Union[str, PurePath], overwrite: bool = False
+    ) -> types.UploadedPayload:
         # Prepare local variables
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
         source = resolve_source(self.source)
@@ -482,7 +496,6 @@ class JobPayload:

         # Handle list entrypoints (custom commands like ["arctic_training"])
         if isinstance(entrypoint, (list, tuple)):
-            payload_name = entrypoint[0] if entrypoint else None
             # For list entrypoints, still upload source if it's a path
             if isinstance(source, Path):
                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
@@ -491,30 +504,24 @@ class JobPayload:
             python_entrypoint: list[Union[str, PurePath]] = list(entrypoint)
         else:
             # Standard file-based entrypoint handling
-            payload_name = None
             if not isinstance(source, types.PayloadPath):
-                if isinstance(source, function_payload_utils.FunctionPayload):
-                    payload_name = source.function.__name__
-
                 source_code = generate_python_code(source, source_code_display=True)
                 _ = session.file.put_stream(
                     io.BytesIO(source_code.encode()),
                     stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                     auto_compress=False,
-                    overwrite=True,
+                    overwrite=overwrite,
                 )
                 source = Path(entrypoint.file_path.parent)

             elif isinstance(source, stage_utils.StagePath):
-                payload_name = entrypoint.file_path.stem
                 # copy payload to stage
                 if source == entrypoint.file_path:
                     source = source.parent
-                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None), overwrite=overwrite)

             elif isinstance(source, Path):
-                payload_name = entrypoint.file_path.stem
-                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None), overwrite=overwrite)
                 if source.is_file():
                     source = source.parent

@@ -565,7 +572,6 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
-            payload_name=payload_name,
         )


@@ -759,3 +765,17 @@ def create_function_payload(
     payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)

     return payload
+
+
+def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
+
+    if entrypoint and isinstance(entrypoint, (list, tuple)):
+        return entrypoint[0]
+    elif entrypoint and isinstance(entrypoint, str):
+        return f"{PurePath(entrypoint).stem}"
+    elif source and not callable(source):
+        return f"{PurePath(source).stem}"
+    elif isinstance(source, function_payload_utils.FunctionPayload):
+        return f"{source.function.__name__}"
+    else:
+        return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
snowflake/ml/jobs/_utils/query_helper.py
@@ -14,10 +14,17 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis


 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
-def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
+def run_query(
+    session: snowpark.Session,
+    query_text: str,
+    params: Optional[Sequence[Any]] = None,
+    statement_params: Optional[dict[str, Any]] = None,
+) -> list[Row]:
     kwargs: dict[str, Any] = {"query": query_text, "params": params}
     if not is_in_stored_procedure():  # type: ignore[no-untyped-call]
         kwargs["_force_qmark_paramstyle"] = True
+    if statement_params:
+        kwargs["_statement_params"] = statement_params
     result = session._conn.run_query(**kwargs)
     if not isinstance(result, dict) or "data" not in result:
         raise ValueError(f"Unprocessable result: {result}")
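
A sketch of the new `statement_params` pass-through (the tag key and value are illustrative only; they are forwarded to the connection as `_statement_params`):

    from snowflake.ml.jobs._utils import query_helper

    rows = query_helper.run_query(
        session,                                # active snowpark.Session (assumed)
        "SELECT CURRENT_VERSION()",
        statement_params={"project": "MLJob"},  # hypothetical telemetry tag
    )
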
snowflake/ml/jobs/_utils/runtime_env_utils.py
@@ -1,8 +1,13 @@
-from typing import Any, Optional, Union
+import datetime
+import logging
+from typing import Any, Literal, Optional, Union

 from packaging.version import Version
 from pydantic import BaseModel, Field, RootModel, field_validator

+from snowflake import snowpark
+from snowflake.ml.jobs._utils import constants, query_helper
+

 class SpcsContainerRuntime(BaseModel):
     python_version: Version = Field(alias="pythonVersion")
@@ -27,6 +32,8 @@ class SpcsContainerRuntime(BaseModel):

 class RuntimeEnvironmentEntry(BaseModel):
     spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
+    created_on: datetime.datetime = Field(alias="createdOn")
+    id: Optional[str] = Field(alias="id")

     class Config:
         extra = "allow"
@@ -57,7 +64,54 @@ class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
         # Filter out any key whose value is not a dict
         return {key: value for key, value in data.items() if isinstance(value, dict)}

-    def get_spcs_container_runtimes(self) -> list[SpcsContainerRuntime]:
-        return [
-            entry.spcs_container_runtime for entry in self.root.values() if entry.spcs_container_runtime is not None
+    def get_spcs_container_runtimes(
+        self,
+        *,
+        hardware_type: Optional[str] = None,
+        python_version: Optional[Version] = None,
+    ) -> list[SpcsContainerRuntime]:
+        # TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label, id or image tag.
+        entries: list[RuntimeEnvironmentEntry] = [
+            entry
+            for entry in self.root.values()
+            if entry.spcs_container_runtime is not None
+            and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
+            and (
+                python_version is None
+                or (
+                    entry.spcs_container_runtime.python_version.major == python_version.major
+                    and entry.spcs_container_runtime.python_version.minor == python_version.minor
+                )
+            )
         ]
+        entries.sort(key=lambda e: e.created_on, reverse=True)
+
+        return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
+
+
+def _extract_image_tag(image_url: str) -> Optional[str]:
+    image_tag = image_url.rsplit(":", 1)[-1]
+    return image_tag
+
+
+def find_runtime_image(
+    session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
+) -> Optional[str]:
+    python_version = (
+        Version(target_python_version) if target_python_version else Version(constants.DEFAULT_PYTHON_VERSION)
+    )
+    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
+    if not rows:
+        return None
+    try:
+        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
+        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes(
+            hardware_type=target_hardware,
+            python_version=python_version,
+        )
+    except Exception as e:
+        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
+        return None
+
+    selected_runtime = spcs_container_runtimes[0] if spcs_container_runtimes else None
+    return selected_runtime.runtime_container_image if selected_runtime else None
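
A sketch of the new lookup, assuming an active Snowpark session (when `target_python_version` is omitted it falls back to `constants.DEFAULT_PYTHON_VERSION`):

    from snowflake.ml.jobs._utils import runtime_env_utils

    # Newest labeled GPU runtime matching Python 3.10, or None if the system
    # function returns nothing, parsing fails, or no runtime matches.
    image = runtime_env_utils.find_runtime_image(session, "GPU", target_python_version="3.10")
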
snowflake/ml/jobs/_utils/spec_utils.py
@@ -1,11 +1,6 @@
-import logging
-import sys
-from typing import Literal, Optional
-
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
 from snowflake.ml.jobs._utils import constants, query_helper, types
-from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict


 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -25,29 +20,3 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
         or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
     )
-
-
-def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
-    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
-    if not rows:
-        return None
-    try:
-        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
-        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
-    except Exception as e:
-        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
-        return None
-
-    selected_runtime = next(
-        (
-            runtime
-            for runtime in spcs_container_runtimes
-            if (
-                runtime.hardware_type.lower() == target_hardware.lower()
-                and runtime.python_version.major == sys.version_info.major
-                and runtime.python_version.minor == sys.version_info.minor
-            )
-        ),
-        None,
-    )
-    return selected_runtime.runtime_container_image if selected_runtime else None
snowflake/ml/jobs/_utils/stage_utils.py
@@ -52,7 +52,7 @@ class StagePath:
         if self._path.parent == Path(""):
             return StagePath(self._root)
         else:
-            return StagePath(f"{self._root}/{self._path.parent}")
+            return StagePath(f"{self._root}/{self._path.parent.as_posix()}")

     @property
     def root(self) -> str:
@@ -67,7 +67,7 @@ class StagePath:
         if path == Path(""):
             return self.root
         else:
-            return f"{self.root}/{path}"
+            return f"{self.root}/{path.as_posix()}"

     def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
         if not other:
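
The `.as_posix()` calls matter on Windows, where formatting a `Path` uses backslashes that would corrupt a stage path; a minimal illustration:

    from pathlib import Path

    p = Path("payload") / "src"
    str(p)        # "payload\\src" on Windows, "payload/src" elsewhere
    p.as_posix()  # "payload/src" on every platform
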
snowflake/ml/jobs/_utils/types.py
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import PurePath
-from typing import Literal, Optional, Protocol, Union, runtime_checkable
+from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable

 from typing_extensions import Self

@@ -103,7 +103,6 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
-    payload_name: Optional[str] = None


 @dataclass(frozen=True)
@@ -128,3 +127,24 @@ class ServiceInfo:
     status: str
     compute_pool: str
     target_instances: int
+
+
+@dataclass
+class JobOptions:
+    external_access_integrations: Optional[list[str]] = None
+    query_warehouse: Optional[str] = None
+    target_instances: Optional[int] = None
+    min_instances: Optional[int] = None
+    use_async: Optional[bool] = True
+    generate_suffix: Optional[bool] = True
+
+
+@dataclass
+class SpecOptions:
+    stage_path: str
+    args: Optional[list[str]] = None
+    env_vars: Optional[dict[str, str]] = None
+    enable_metrics: Optional[bool] = None
+    spec_overrides: Optional[dict[str, Any]] = None
+    runtime: Optional[str] = None
+    enable_stage_mount_v2: Optional[bool] = True
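
These option bundles are serialized to JSON with uppercased keys when building the `SYSTEM$EXECUTE_ML_JOB` call (see `MLJobDefinition.to_sql` below); a small sketch with a hypothetical stage path:

    import dataclasses
    import json

    from snowflake.ml.jobs._utils import types

    spec = types.SpecOptions(stage_path="@my_stage/app", enable_metrics=True)
    print(json.dumps({k.upper(): v for k, v in dataclasses.asdict(spec).items()}))
    # {"STAGE_PATH": "@my_stage/app", "ARGS": null, ...}
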
snowflake/ml/jobs/job_definition.py
@@ -0,0 +1,232 @@
+import dataclasses
+import json
+import logging
+import os
+import sys
+from pathlib import PurePath, PurePosixPath
+from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from uuid import uuid4
+
+from typing_extensions import ParamSpec
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import (
+    constants,
+    feature_flags,
+    payload_utils,
+    query_helper,
+    types,
+)
+from snowflake.snowpark import context as sp_context
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_Args = ParamSpec("_Args")
+_ReturnValue = TypeVar("_ReturnValue")
+JOB_ID_PREFIX = "MLJOB_"
+_PROJECT = "MLJob"
+logger = logging.getLogger(__name__)
+
+
+class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
+    def __init__(
+        self,
+        job_options: types.JobOptions,
+        spec_options: types.SpecOptions,
+        stage_name: str,
+        compute_pool: str,
+        name: str,
+        entrypoint_args: list[Any],
+        database: Optional[str] = None,
+        schema: Optional[str] = None,
+        session: Optional[snowpark.Session] = None,
+    ) -> None:
+        self.stage_name = stage_name
+        self.job_options = job_options
+        self.spec_options = spec_options
+        self.compute_pool = compute_pool
+        self.session = session or sp_context.get_active_session()
+        self.database = database or self.session.get_current_database()
+        self.schema = schema or self.session.get_current_schema()
+        self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
+        self.entrypoint_args = entrypoint_args
+
+    def delete(self) -> None:
+        if self.stage_name:
+            try:
+                self.session.sql(f"REMOVE {self.stage_name}/").collect()
+                logger.debug(f"Successfully cleaned up stage files for job definition {self.stage_name}")
+            except Exception as e:
+                logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")
+
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
+        # TODO: Add ArgProtocol and respective logics
+        return [arg for arg in args]
+
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+        statement_params = telemetry.get_statement_params(_PROJECT)
+        statement_params = telemetry.add_statement_params_custom_tags(
+            statement_params,
+            custom_tags={
+                "job_definition_id": self.job_definition_id,
+            },
+        )
+        args_list = self._prepare_arguments(*args, **kwargs)
+        query = self.to_sql(job_args=args_list, use_async=True)
+        job_id = query_helper.run_query(self.session, query, statement_params=statement_params)[0][0]
+        return jb.MLJob[_ReturnValue](job_id, session=self.session)
+
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def to_sql(self, *, job_args: Optional[list[Any]] = None, use_async: bool = False) -> str:
+        # Combine the entrypoint_args and job_args for use in the query
+        combined_args = (self.entrypoint_args or []) + (job_args or [])
+        spec_options = dataclasses.replace(self.spec_options, args=combined_args)
+        # Uppercase option keys to match the expected SYSTEM$EXECUTE_ML_JOB parameter format
+        spec_options_dict = {k.upper(): v for k, v in dataclasses.asdict(spec_options).items()}
+        job_options = dataclasses.replace(self.job_options, use_async=use_async)
+        # Uppercase option keys to match the expected SYSTEM$EXECUTE_ML_JOB parameter format
+        job_options_dict = {k.upper(): v for k, v in dataclasses.asdict(job_options).items()}
+        job_options_dict["ASYNC"] = job_options_dict.pop("USE_ASYNC")
+        params = [
+            self.job_definition_id + ("_" if self.job_options.generate_suffix else ""),
+            self.compute_pool,
+            json.dumps(spec_options_dict),
+            json.dumps(job_options_dict),
+        ]
+        query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
+        return sql
+
+    @classmethod
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        func_params_to_log=[
+            "pip_requirements",
+            "external_access_integrations",
+            "target_instances",
+            "min_instances",
+            "enable_metrics",
+            "query_warehouse",
+            "runtime_environment",
+        ],
+    )
+    def register(
+        cls,
+        source: Union[str, Callable[_Args, _ReturnValue]],
+        compute_pool: str,
+        stage_name: str,
+        session: Optional[snowpark.Session] = None,
+        entrypoint: Optional[Union[str, list[str]]] = None,
+        target_instances: int = 1,
+        generate_suffix: bool = True,
+        **kwargs: Any,
+    ) -> "MLJobDefinition[_Args, _ReturnValue]":
+        # Use kwargs for less common optional parameters
+        database = kwargs.pop("database", None)
+        schema = kwargs.pop("schema", None)
+        min_instances = kwargs.pop("min_instances", target_instances)
+        pip_requirements = kwargs.pop("pip_requirements", None)
+        external_access_integrations = kwargs.pop("external_access_integrations", None)
+        env_vars = kwargs.pop("env_vars", None)
+        spec_overrides = kwargs.pop("spec_overrides", None)
+        enable_metrics = kwargs.pop("enable_metrics", True)
+        session = session or sp_context.get_active_session()
+        query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
+        imports = kwargs.pop("imports", None)
+        runtime_environment = kwargs.pop(
+            "runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None)
+        )
+        overwrite = kwargs.pop("overwrite", False)
+        name = kwargs.pop("name", None)
+        # Warn if there are unknown kwargs
+        if kwargs:
+            logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+        # Validate parameters
+        if database and not schema:
+            raise ValueError("Schema must be specified if database is specified.")
+        if target_instances < 1:
+            raise ValueError("target_instances must be greater than 0.")
+        if not (0 < min_instances <= target_instances):
+            raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
+        if min_instances > 1:
+            # Validate min_instances against compute pool max_nodes
+            pool_info = jb._get_compute_pool_info(session, compute_pool)
+            max_nodes = int(pool_info["max_nodes"])
+            if min_instances > max_nodes:
+                raise ValueError(
+                    f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                    f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+                )
+
+        if name:
+            parsed_database, parsed_schema, parsed_name = identifier.parse_schema_level_object_identifier(name)
+            database = parsed_database or database
+            schema = parsed_schema or schema
+            name = parsed_name
+        else:
+            name = payload_utils.get_payload_name(source, entrypoint)
+
+        # The logical identifier for this job definition (used in the stage path)
+        # is the resolved object name, not the fully qualified identifier.
+        job_definition_id = name if not generate_suffix else name + _generate_suffix()
+        stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
+        stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
+        stage_path = PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_definition_id}")
+
+        try:
+            # Upload payload
+            uploaded_payload = payload_utils.JobPayload(
+                source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
+            ).upload(session, stage_path, overwrite)
+        except SnowparkSQLException as e:
+            if e.sql_error_code == 90106:
+                raise RuntimeError(
+                    "Please specify a schema, either in the session context or as a parameter in the job submission"
+                )
+            raise
+
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(default=True):
+            # Pass a JSON object for runtime versions so it serializes as nested JSON in options
+            runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
+
+        combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+        entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
+        spec_options = types.SpecOptions(
+            stage_path=stage_path.as_posix(),
+            # the args will be set at runtime
+            args=None,
+            env_vars=combined_env_vars,
+            enable_metrics=enable_metrics,
+            spec_overrides=spec_overrides,
+            runtime=runtime_environment if runtime_environment else None,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True),
+        )
+
+        job_options = types.JobOptions(
+            external_access_integrations=external_access_integrations,
+            query_warehouse=query_warehouse,
+            target_instances=target_instances,
+            min_instances=min_instances,
+            generate_suffix=generate_suffix,
+        )
+
+        return cls(
+            stage_name=stage_path.as_posix(),
+            spec_options=spec_options,
+            job_options=job_options,
+            compute_pool=compute_pool,
+            entrypoint_args=entrypoint_args,
+            session=session,
+            database=database,
+            schema=schema,
+            name=name,
+        )
+
+
+def _generate_suffix() -> str:
+    return str(uuid4().hex)[:8]
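
Putting the pieces together, a hedged end-to-end sketch of the new API (assumes an active Snowpark session; the compute pool, stage, and table names are hypothetical):

    from snowflake.ml.jobs import MLJobDefinition

    def train(data_table: str) -> None:
        ...  # training body runs remotely

    # Register once: uploads the payload and captures spec/job options.
    job_def = MLJobDefinition.register(
        train,
        compute_pool="MY_POOL",
        stage_name="@ML_JOBS_STAGE",
        target_instances=1,
    )
    print(job_def.to_sql())  # inspect the generated SYSTEM$EXECUTE_ML_JOB call
    job = job_def("MY_DB.PUBLIC.TRAINING_DATA")  # submits asynchronously, returns an MLJob
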