snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +2 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +109 -32
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +45 -2
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +81 -61
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +30 -29
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/url.py ADDED
@@ -0,0 +1,42 @@
+from urllib.parse import urlunparse
+
+from snowflake import snowpark as snowpark
+
+JOB_URL_PREFIX = "#/compute/job/"
+SERVICE_URL_PREFIX = "#/compute/service/"
+
+
+def get_snowflake_url(
+    session: snowpark.Session,
+    url_path: str,
+    params: str = "",
+    query: str = "",
+    fragment: str = "",
+) -> str:
+    """Construct a Snowflake URL from session connection details and URL components.
+
+    Args:
+        session: The Snowpark session containing connection details.
+        url_path: The path component of the URL (e.g., "/compute/job/123").
+        params: Optional parameters for the URL (RFC 1808). Defaults to "".
+        query: Optional query string for the URL. Defaults to "".
+        fragment: Optional fragment identifier for the URL (e.g., "#section"). Defaults to "".
+
+    Returns:
+        A fully constructed Snowflake URL string with scheme, host, and specified components.
+    """
+    scheme = "https"
+    if hasattr(session.connection, "scheme"):
+        scheme = session.connection.scheme
+    host = session.connection.host
+
+    return urlunparse(
+        (
+            scheme,
+            host,
+            url_path,
+            params,
+            query,
+            fragment,
+        )
+    )
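The new helper builds deep links from the session's host with the standard library's urlunparse. A standalone sketch of the same composition, using only the standard library; the account host and job id below are hypothetical and the exact way callers combine JOB_URL_PREFIX is outside this diff:

# Standalone illustration of the URL composition get_snowflake_url performs
# (hypothetical host and job id; urlunparse joins scheme, netloc, path,
# params, query, and fragment).
from urllib.parse import urlunparse

host = "myorg-myaccount.snowflakecomputing.com"  # hypothetical account host
job_url = urlunparse(("https", host, "", "", "", "/compute/job/MLJOB_ABC123"))
print(job_url)  # https://myorg-myaccount.snowflakecomputing.com#/compute/job/MLJOB_ABC123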
snowflake/ml/jobs/__init__.py CHANGED
@@ -2,6 +2,7 @@ from snowflake.ml.jobs._interop.exception_utils import install_exception_display
 from snowflake.ml.jobs._utils.types import JOB_STATUS
 from snowflake.ml.jobs.decorators import remote
 from snowflake.ml.jobs.job import MLJob
+from snowflake.ml.jobs.job_definition import MLJobDefinition
 from snowflake.ml.jobs.manager import (
     delete_job,
     get_job,
@@ -24,4 +25,5 @@ __all__ = [
     "MLJob",
     "JOB_STATUS",
     "submit_from_stage",
+    "MLJobDefinition",
 ]
snowflake/ml/jobs/_utils/constants.py CHANGED
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
+DEFAULT_PYTHON_VERSION = "3.10"
 
 # Environment variables
 STAGE_MOUNT_PATH_ENV_VAR = "MLRS_STAGE_MOUNT_PATH"
@@ -30,6 +31,7 @@ DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
 DEFAULT_IMAGE_TAG = "1.8.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
+DEFAULT_PYTHON_VERSION = "3.10"
 
 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3
snowflake/ml/jobs/_utils/payload_utils.py CHANGED
@@ -11,6 +11,7 @@ from importlib.abc import Traversable
 from pathlib import Path, PurePath
 from types import ModuleType
 from typing import IO, Any, Callable, Optional, Union, cast, get_args, get_origin
+from uuid import uuid4
 
 import cloudpickle as cp
 from packaging import version
@@ -36,10 +37,15 @@ _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
 _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
+JOB_ID_PREFIX = "MLJOB_"
 
 
 def _compress_and_upload_file(
-    session: snowpark.Session,
+    session: snowpark.Session,
+    source_path: Path,
+    stage_path: PurePath,
+    import_path: Optional[str] = None,
+    overwrite: bool = True,
 ) -> None:
     absolute_source_path = source_path.absolute()
     leading_path = absolute_source_path.as_posix()[: -len(import_path)] if import_path else None
@@ -49,11 +55,13 @@ def _compress_and_upload_file(
         cast(IO[bytes], stream),
         stage_path.joinpath(filename).as_posix(),
         auto_compress=False,
-        overwrite=
+        overwrite=overwrite,
     )
 
 
-def _upload_directory(
+def _upload_directory(
+    session: snowpark.Session, source_path: Path, payload_stage_path: PurePath, overwrite: bool = True
+) -> None:
     # Manually traverse the directory and upload each file, since Snowflake PUT
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
@@ -81,12 +89,14 @@ def _upload_directory(session: snowpark.Session, source_path: Path, payload_stag
         session.file.put(
             str(path),
             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
-            overwrite=
+            overwrite=overwrite,
            auto_compress=False,
         )
 
 
-def upload_payloads(
+def upload_payloads(
+    session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec, overwrite: bool = True
+) -> None:
     for spec in payload_specs:
         source_path = spec.source_path
         remote_relative_path = spec.remote_relative_path
@@ -109,6 +119,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
                 source_path,
                 stage_path,
                 remote_relative_path.as_posix() if remote_relative_path else None,
+                overwrite=overwrite,
             )
         else:
             _upload_directory(session, source_path, payload_stage_path)
@@ -120,12 +131,13 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
                 source_path,
                 stage_path,
                 remote_relative_path.as_posix() if remote_relative_path else None,
+                overwrite=overwrite,
             )
         else:
             session.file.put(
                 str(source_path.resolve()),
                 payload_stage_path.as_posix(),
-                overwrite=
+                overwrite=overwrite,
                 auto_compress=False,
             )
 
@@ -455,7 +467,9 @@ class JobPayload:
         self.pip_requirements = pip_requirements
         self.imports = imports
 
-    def upload(
+    def upload(
+        self, session: snowpark.Session, stage_path: Union[str, PurePath], overwrite: bool = False
+    ) -> types.UploadedPayload:
         # Prepare local variables
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
         source = resolve_source(self.source)
@@ -482,7 +496,6 @@ class JobPayload:
 
         # Handle list entrypoints (custom commands like ["arctic_training"])
         if isinstance(entrypoint, (list, tuple)):
-            payload_name = entrypoint[0] if entrypoint else None
             # For list entrypoints, still upload source if it's a path
             if isinstance(source, Path):
                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
@@ -491,30 +504,24 @@ class JobPayload:
             python_entrypoint: list[Union[str, PurePath]] = list(entrypoint)
         else:
             # Standard file-based entrypoint handling
-            payload_name = None
             if not isinstance(source, types.PayloadPath):
-                if isinstance(source, function_payload_utils.FunctionPayload):
-                    payload_name = source.function.__name__
-
                 source_code = generate_python_code(source, source_code_display=True)
                 _ = session.file.put_stream(
                     io.BytesIO(source_code.encode()),
                     stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                     auto_compress=False,
-                    overwrite=
+                    overwrite=overwrite,
                 )
                 source = Path(entrypoint.file_path.parent)
 
             elif isinstance(source, stage_utils.StagePath):
-                payload_name = entrypoint.file_path.stem
                 # copy payload to stage
                 if source == entrypoint.file_path:
                     source = source.parent
-                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None), overwrite=overwrite)
 
             elif isinstance(source, Path):
-
-                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+                upload_payloads(session, app_stage_path, types.PayloadSpec(source, None), overwrite=overwrite)
                 if source.is_file():
                     source = source.parent
 
@@ -565,7 +572,6 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
-            payload_name=payload_name,
        )
 
 
@@ -759,3 +765,17 @@ def create_function_payload(
     payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
 
     return payload
+
+
+def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
+
+    if entrypoint and isinstance(entrypoint, (list, tuple)):
+        return entrypoint[0]
+    elif entrypoint and isinstance(entrypoint, str):
+        return f"{PurePath(entrypoint).stem}"
+    elif source and not callable(source):
+        return f"{PurePath(source).stem}"
+    elif isinstance(source, function_payload_utils.FunctionPayload):
+        return f"{source.function.__name__}"
+    else:
+        return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
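The new get_payload_name helper derives a default name from the entrypoint or source file stem, falling back to a random MLJOB_-prefixed identifier. A standalone sketch of the stem-based rule using only pathlib; the paths are hypothetical:

# Standalone illustration of the stem-based naming rule used by get_payload_name
# (hypothetical paths; the fallback case instead generates MLJOB_<random> names).
from pathlib import PurePath

print(PurePath("src/train.py").stem)  # "train"      <- string entrypoint
print(PurePath("my_script.py").stem)  # "my_script"  <- non-callable source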
snowflake/ml/jobs/_utils/query_helper.py CHANGED
@@ -14,10 +14,17 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis
 
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
-def run_query(
+def run_query(
+    session: snowpark.Session,
+    query_text: str,
+    params: Optional[Sequence[Any]] = None,
+    statement_params: Optional[dict[str, Any]] = None,
+) -> list[Row]:
     kwargs: dict[str, Any] = {"query": query_text, "params": params}
     if not is_in_stored_procedure():  # type: ignore[no-untyped-call]
         kwargs["_force_qmark_paramstyle"] = True
+    if statement_params:
+        kwargs["_statement_params"] = statement_params
     result = session._conn.run_query(**kwargs)
     if not isinstance(result, dict) or "data" not in result:
         raise ValueError(f"Unprocessable result: {result}")
snowflake/ml/jobs/_utils/runtime_env_utils.py CHANGED
@@ -1,8 +1,13 @@
-
+import datetime
+import logging
+from typing import Any, Literal, Optional, Union
 
 from packaging.version import Version
 from pydantic import BaseModel, Field, RootModel, field_validator
 
+from snowflake import snowpark
+from snowflake.ml.jobs._utils import constants, query_helper
+
 
 class SpcsContainerRuntime(BaseModel):
     python_version: Version = Field(alias="pythonVersion")
@@ -27,6 +32,8 @@ class SpcsContainerRuntime(BaseModel):
 
 class RuntimeEnvironmentEntry(BaseModel):
     spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
+    created_on: datetime.datetime = Field(alias="createdOn")
+    id: Optional[str] = Field(alias="id")
 
     class Config:
         extra = "allow"
@@ -57,7 +64,54 @@ class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
         # Filter out any key whose value is not a dict
         return {key: value for key, value in data.items() if isinstance(value, dict)}
 
-    def get_spcs_container_runtimes(
-
-
+    def get_spcs_container_runtimes(
+        self,
+        *,
+        hardware_type: Optional[str] = None,
+        python_version: Optional[Version] = None,
+    ) -> list[SpcsContainerRuntime]:
+        # TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label,id or image tag.
+        entries: list[RuntimeEnvironmentEntry] = [
+            entry
+            for entry in self.root.values()
+            if entry.spcs_container_runtime is not None
+            and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
+            and (
+                python_version is None
+                or (
+                    entry.spcs_container_runtime.python_version.major == python_version.major
+                    and entry.spcs_container_runtime.python_version.minor == python_version.minor
+                )
+            )
         ]
+        entries.sort(key=lambda e: e.created_on, reverse=True)
+
+        return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
+
+
+def _extract_image_tag(image_url: str) -> Optional[str]:
+    image_tag = image_url.rsplit(":", 1)[-1]
+    return image_tag
+
+
+def find_runtime_image(
+    session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
+) -> Optional[str]:
+    python_version = (
+        Version(target_python_version) if target_python_version else Version(constants.DEFAULT_PYTHON_VERSION)
+    )
+    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
+    if not rows:
+        return None
+    try:
+        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
+        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes(
+            hardware_type=target_hardware,
+            python_version=python_version,
+        )
+    except Exception as e:
+        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
+        return None
+
+    selected_runtime = spcs_container_runtimes[0] if spcs_container_runtimes else None
+    return selected_runtime.runtime_container_image if selected_runtime else None
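The new filter matches candidate runtimes on Python major.minor only and then prefers the most recently created entry. A standalone sketch of that comparison with packaging; the version strings are hypothetical:

# Standalone illustration of the major.minor match used when filtering
# SPCS container runtimes (hypothetical version strings).
from packaging.version import Version

requested = Version("3.10")
candidate = Version("3.10.14")
matches = (candidate.major, candidate.minor) == (requested.major, requested.minor)
print(matches)  # True: the patch level is ignored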
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -1,11 +1,6 @@
-import logging
-import sys
-from typing import Literal, Optional
-
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
 from snowflake.ml.jobs._utils import constants, query_helper, types
-from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -25,29 +20,3 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
         or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
     )
-
-
-def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
-    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
-    if not rows:
-        return None
-    try:
-        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
-        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
-    except Exception as e:
-        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
-        return None
-
-    selected_runtime = next(
-        (
-            runtime
-            for runtime in spcs_container_runtimes
-            if (
-                runtime.hardware_type.lower() == target_hardware.lower()
-                and runtime.python_version.major == sys.version_info.major
-                and runtime.python_version.minor == sys.version_info.minor
-            )
-        ),
-        None,
-    )
-    return selected_runtime.runtime_container_image if selected_runtime else None
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -52,7 +52,7 @@ class StagePath:
         if self._path.parent == Path(""):
             return StagePath(self._root)
         else:
-            return StagePath(f"{self._root}/{self._path.parent}")
+            return StagePath(f"{self._root}/{self._path.parent.as_posix()}")
 
     @property
     def root(self) -> str:
@@ -67,7 +67,7 @@ class StagePath:
         if path == Path(""):
             return self.root
         else:
-            return f"{self.root}/{path}"
+            return f"{self.root}/{path.as_posix()}"
 
     def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
         if not other:
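Interpolating a Path directly into the stage location uses the platform separator, so the change to .as_posix() keeps stage paths forward-slashed on Windows clients. A standalone stdlib illustration, with PureWindowsPath standing in for a Windows environment:

# Standalone illustration of why .as_posix() matters when interpolating a
# pathlib path into a stage location (PureWindowsPath simulates Windows).
from pathlib import PureWindowsPath

p = PureWindowsPath("dir/sub/file.py").parent
print(f"@my_stage/{p}")             # @my_stage/dir\sub  <- backslashes
print(f"@my_stage/{p.as_posix()}")  # @my_stage/dir/sub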
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import PurePath
-from typing import Literal, Optional, Protocol, Union, runtime_checkable
+from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable
 
 from typing_extensions import Self
 
@@ -103,7 +103,6 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
-    payload_name: Optional[str] = None
 
 
 @dataclass(frozen=True)
@@ -128,3 +127,24 @@ class ServiceInfo:
     status: str
     compute_pool: str
     target_instances: int
+
+
+@dataclass
+class JobOptions:
+    external_access_integrations: Optional[list[str]] = None
+    query_warehouse: Optional[str] = None
+    target_instances: Optional[int] = None
+    min_instances: Optional[int] = None
+    use_async: Optional[bool] = True
+    generate_suffix: Optional[bool] = True
+
+
+@dataclass
+class SpecOptions:
+    stage_path: str
+    args: Optional[list[str]] = None
+    env_vars: Optional[dict[str, str]] = None
+    enable_metrics: Optional[bool] = None
+    spec_overrides: Optional[dict[str, Any]] = None
+    runtime: Optional[str] = None
+    enable_stage_mount_v2: Optional[bool] = True
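MLJobDefinition.to_sql (below) serializes these dataclasses into the JSON arguments of SYSTEM$EXECUTE_ML_JOB, uppercasing the keys and renaming USE_ASYNC to ASYNC. A hedged sketch of that transformation; the values are hypothetical and the snippet assumes snowflake-ml-python is installed:

# Sketch of how the option dataclasses become SYSTEM$EXECUTE_ML_JOB arguments,
# mirroring MLJobDefinition.to_sql (hypothetical values).
import dataclasses
import json

from snowflake.ml.jobs._utils import types

spec = types.SpecOptions(stage_path="@MY_DB.MY_SCHEMA.MY_STAGE/my_job", args=["main.py"])
job = types.JobOptions(target_instances=2, min_instances=1, use_async=True)

spec_dict = {k.upper(): v for k, v in dataclasses.asdict(spec).items()}
job_dict = {k.upper(): v for k, v in dataclasses.asdict(job).items()}
job_dict["ASYNC"] = job_dict.pop("USE_ASYNC")  # the system function expects ASYNC

print(json.dumps(spec_dict))
print(json.dumps(job_dict))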
snowflake/ml/jobs/job_definition.py ADDED
@@ -0,0 +1,232 @@
+import dataclasses
+import json
+import logging
+import os
+import sys
+from pathlib import PurePath, PurePosixPath
+from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from uuid import uuid4
+
+from typing_extensions import ParamSpec
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import (
+    constants,
+    feature_flags,
+    payload_utils,
+    query_helper,
+    types,
+)
+from snowflake.snowpark import context as sp_context
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_Args = ParamSpec("_Args")
+_ReturnValue = TypeVar("_ReturnValue")
+JOB_ID_PREFIX = "MLJOB_"
+_PROJECT = "MLJob"
+logger = logging.getLogger(__name__)
+
+
+class MLJobDefinition(Generic[_Args, _ReturnValue], SerializableSessionMixin):
+    def __init__(
+        self,
+        job_options: types.JobOptions,
+        spec_options: types.SpecOptions,
+        stage_name: str,
+        compute_pool: str,
+        name: str,
+        entrypoint_args: list[Any],
+        database: Optional[str] = None,
+        schema: Optional[str] = None,
+        session: Optional[snowpark.Session] = None,
+    ) -> None:
+        self.stage_name = stage_name
+        self.job_options = job_options
+        self.spec_options = spec_options
+        self.compute_pool = compute_pool
+        self.session = session or sp_context.get_active_session()
+        self.database = database or self.session.get_current_database()
+        self.schema = schema or self.session.get_current_schema()
+        self.job_definition_id = identifier.get_schema_level_object_identifier(self.database, self.schema, name)
+        self.entrypoint_args = entrypoint_args
+
+    def delete(self) -> None:
+        if self.stage_name:
+            try:
+                self.session.sql(f"REMOVE {self.stage_name}/").collect()
+                logger.debug(f"Successfully cleaned up stage files for job definition {self.stage_name}")
+            except Exception as e:
+                logger.warning(f"Failed to clean up stage files for job definition {self.stage_name}: {e}")
+
+    def _prepare_arguments(self, *args: _Args.args, **kwargs: _Args.kwargs) -> list[Any]:
+        # TODO: Add ArgProtocol and respective logics
+        return [arg for arg in args]
+
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def __call__(self, *args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+        statement_params = telemetry.get_statement_params(_PROJECT)
+        statement_params = telemetry.add_statement_params_custom_tags(
+            statement_params,
+            custom_tags={
+                "job_definition_id": self.job_definition_id,
+            },
+        )
+        args_list = self._prepare_arguments(*args, **kwargs)
+        query = self.to_sql(job_args=args_list, use_async=True)
+        job_id = query_helper.run_query(self.session, query, statement_params=statement_params)[0][0]
+        return jb.MLJob[_ReturnValue](job_id, session=self.session)
+
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def to_sql(self, *, job_args: Optional[list[Any]] = None, use_async: bool = False) -> str:
+        # Combine the entrypoint_args and job_args for use in the query
+        combined_args = (self.entrypoint_args or []) + (job_args or [])
+        spec_options = dataclasses.replace(self.spec_options, args=combined_args)
+        # Uppercase option keys to match the expected SYSTEM$EXECUTE_ML_JOB parameter format
+        spec_options_dict = {k.upper(): v for k, v in dataclasses.asdict(spec_options).items()}
+        job_options = dataclasses.replace(self.job_options, use_async=use_async)
+        # Uppercase option keys to match the expected SYSTEM$EXECUTE_ML_JOB parameter format
+        job_options_dict = {k.upper(): v for k, v in dataclasses.asdict(job_options).items()}
+        job_options_dict["ASYNC"] = job_options_dict.pop("USE_ASYNC")
+        params = [
+            self.job_definition_id + ("_" if self.job_options.generate_suffix else ""),
+            self.compute_pool,
+            json.dumps(spec_options_dict),
+            json.dumps(job_options_dict),
+        ]
+        query_template = "CALL SYSTEM$EXECUTE_ML_JOB(%s, %s, %s, %s)"
+        sql = self.session._conn._cursor._preprocess_pyformat_query(query_template, params)
+        return sql
+
+    @classmethod
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        func_params_to_log=[
+            "pip_requirements",
+            "external_access_integrations",
+            "target_instances",
+            "min_instances",
+            "enable_metrics",
+            "query_warehouse",
+            "runtime_environment",
+        ],
+    )
+    def register(
+        cls,
+        source: Union[str, Callable[_Args, _ReturnValue]],
+        compute_pool: str,
+        stage_name: str,
+        session: Optional[snowpark.Session] = None,
+        entrypoint: Optional[Union[str, list[str]]] = None,
+        target_instances: int = 1,
+        generate_suffix: bool = True,
+        **kwargs: Any,
+    ) -> "MLJobDefinition[_Args, _ReturnValue]":
+        # Use kwargs for less common optional parameters
+        database = kwargs.pop("database", None)
+        schema = kwargs.pop("schema", None)
+        min_instances = kwargs.pop("min_instances", target_instances)
+        pip_requirements = kwargs.pop("pip_requirements", None)
+        external_access_integrations = kwargs.pop("external_access_integrations", None)
+        env_vars = kwargs.pop("env_vars", None)
+        spec_overrides = kwargs.pop("spec_overrides", None)
+        enable_metrics = kwargs.pop("enable_metrics", True)
+        session = session or sp_context.get_active_session()
+        query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
+        imports = kwargs.pop("imports", None)
+        runtime_environment = kwargs.pop(
+            "runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None)
+        )
+        overwrite = kwargs.pop("overwrite", False)
+        name = kwargs.pop("name", None)
+        # Warn if there are unknown kwargs
+        if kwargs:
+            logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+        # Validate parameters
+        if database and not schema:
+            raise ValueError("Schema must be specified if database is specified.")
+        if target_instances < 1:
+            raise ValueError("target_instances must be greater than 0.")
+        if not (0 < min_instances <= target_instances):
+            raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
+        if min_instances > 1:
+            # Validate min_instances against compute pool max_nodes
+            pool_info = jb._get_compute_pool_info(session, compute_pool)
+            max_nodes = int(pool_info["max_nodes"])
+            if min_instances > max_nodes:
+                raise ValueError(
+                    f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                    f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+                )
+
+        if name:
+            parsed_database, parsed_schema, parsed_name = identifier.parse_schema_level_object_identifier(name)
+            database = parsed_database or database
+            schema = parsed_schema or schema
+            name = parsed_name
+        else:
+            name = payload_utils.get_payload_name(source, entrypoint)
+
+        # The logical identifier for this job definition (used in the stage path)
+        # is the resolved object name, not the fully qualified identifier.
+        job_definition_id = name if not generate_suffix else name + _generate_suffix()
+        stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
+        stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
+        stage_path = PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_definition_id}")
+
+        try:
+            # Upload payload
+            uploaded_payload = payload_utils.JobPayload(
+                source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
+            ).upload(session, stage_path, overwrite)
+        except SnowparkSQLException as e:
+            if e.sql_error_code == 90106:
+                raise RuntimeError(
+                    "Please specify a schema, either in the session context or as a parameter in the job submission"
+                )
+            raise
+
+        if runtime_environment is None and feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(default=True):
+            # Pass a JSON object for runtime versions so it serializes as nested JSON in options
+            runtime_environment = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
+
+        combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+        entrypoint_args = [v.as_posix() if isinstance(v, PurePath) else v for v in uploaded_payload.entrypoint]
+        spec_options = types.SpecOptions(
+            stage_path=stage_path.as_posix(),
+            # the args will be set at runtime
+            args=None,
+            env_vars=combined_env_vars,
+            enable_metrics=enable_metrics,
+            spec_overrides=spec_overrides,
+            runtime=runtime_environment if runtime_environment else None,
+            enable_stage_mount_v2=feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True),
+        )
+
+        job_options = types.JobOptions(
+            external_access_integrations=external_access_integrations,
+            query_warehouse=query_warehouse,
+            target_instances=target_instances,
+            min_instances=min_instances,
+            generate_suffix=generate_suffix,
+        )
+
+        return cls(
+            stage_name=stage_path.as_posix(),
+            spec_options=spec_options,
+            job_options=job_options,
+            compute_pool=compute_pool,
+            entrypoint_args=entrypoint_args,
+            session=session,
+            database=database,
+            schema=schema,
+            name=name,
+        )
+
+
+def _generate_suffix() -> str:
+    return str(uuid4().hex)[:8]
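MLJobDefinition registers a payload once and can then be submitted repeatedly, with each call returning an MLJob handle. A minimal usage sketch, assuming an active Snowpark session, an existing compute pool, and a stage the current role can write to; all object names are hypothetical and entrypoint semantics follow the existing job-submission API:

# Minimal usage sketch (hypothetical names; requires an active Snowpark session).
from snowflake.ml.jobs import MLJobDefinition

job_def = MLJobDefinition.register(
    source="./src",                              # local payload uploaded to the stage
    entrypoint="src/train.py",                   # script run inside the container
    compute_pool="MY_COMPUTE_POOL",
    stage_name="@MY_DB.MY_SCHEMA.ML_JOB_STAGE",
    target_instances=1,
)

print(job_def.to_sql())  # inspect the SYSTEM$EXECUTE_ML_JOB call without submitting
job = job_def()          # submit; returns an MLJob handle
print(job.status)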