snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +59 -0
- snowflake/ml/experiment/callback/xgboost.py +67 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +55 -21
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +219 -55
- snowflake/ml/model/_client/ops/service_ops.py +230 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +74 -51
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +37 -70
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,15 @@ from typing import Optional
 
 import platformdirs
 
+# Module-level logger for operational messages that should appear on console
+stdout_handler = logging.StreamHandler(sys.stdout)
+stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+
+console_logger = logging.getLogger(__name__)
+console_logger.addHandler(stdout_handler)
+console_logger.setLevel(logging.INFO)
+console_logger.propagate = False
+
 
 class LogColor(enum.Enum):
     GREY = "\x1b[38;20m"
@@ -109,42 +118,36 @@ def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
     """Get or create a parent logger with FileHandler for the operation."""
     parent_logger_name = f"snowflake_ml_operation_{operation_id}"
     parent_logger = logging.getLogger(parent_logger_name)
+    parent_logger.setLevel(logging.DEBUG)
+    parent_logger.propagate = False
 
-    # Only add handler if it doesn't exist yet
     if not parent_logger.handlers:
         log_file_path = _get_log_file_path(operation_id)
 
         if log_file_path:
-            # Successfully found a writable location
             try:
                 file_handler = logging.FileHandler(log_file_path)
                 file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
                 parent_logger.addHandler(file_handler)
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False  # Don't propagate to root logger
 
-
-                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+                console_logger.info(f"create_service logs saved to: {log_file_path}")
             except OSError as e:
-
-                # Fall back to console-only logging
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False
-                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+                console_logger.warning(f"Could not create log file at {log_file_path}: {e}.")
         else:
             # No writable location found, use console-only logging
-
-
-
+            console_logger.warning("No writable location found for create_service log file.")
+
+    if logging.getLogger().level > logging.INFO:
+        console_logger.info(
+            "To see logs in console, set log level to INFO: logging.getLogger().setLevel(logging.INFO)"
+        )
 
     return parent_logger
 
 
 def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-
-    handler.setFormatter(CustomFormatter(info_color))
-    logger.addHandler(handler)
+    root_logger = logging.getLogger()
 
     # If operation_id provided, set up parent logger with file handler
     if operation_id:
@@ -152,6 +155,17 @@ def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[st
         logger.parent = parent_logger
         logger.propagate = True
 
+        if root_logger.level <= logging.INFO:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+    else:
+        # No operation_id - add console handler only if user wants verbose logging
+        if root_logger.level <= logging.INFO and not logger.handlers:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+
     return logger
 
 
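
Note: per the hunks above, the per-operation console handlers are only attached when the root logger is at INFO or below, and the module emits the opt-in hint itself. On the caller side that opt-in would look roughly like this (a minimal sketch, not part of the diff):

import logging

# Opt in to seeing create_service progress messages on stdout,
# per the hint printed by _get_or_create_parent_logger above.
logging.getLogger().setLevel(logging.INFO)
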
@@ -0,0 +1,63 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import keras
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeKerasCallback(keras.callbacks.Callback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_params:
+            params = json.loads(self.model.to_json())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_metrics and logs and epoch % self.log_every_n_epochs == 0:
+            for key, value in logs.items():
+                try:
+                    value = float(value)
+                except Exception:
+                    pass
+                else:
+                    self._experiment_tracking.log_metric(key=key, value=value, step=epoch)
+
+    def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeKerasCallback."
+                )
+                return
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=self.model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
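
For orientation, a minimal usage sketch of the new Keras callback; the ExperimentTracking instance, the model signature, and the model/data names are placeholders, since their construction is not part of this diff.

# Hypothetical usage of SnowflakeKerasCallback (exp, sig, model and data are placeholders).
from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback

exp = ...  # an ExperimentTracking instance, set up elsewhere
sig = ...  # a ModelSignature describing the model's predict inputs/outputs

callback = SnowflakeKerasCallback(
    experiment_tracking=exp,
    log_every_n_epochs=5,   # log metrics on every 5th epoch
    model_signature=sig,    # required for the model to be logged at the end of training
)
model.fit(x_train, y_train, epochs=20, callbacks=[callback])
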
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING, Optional
+from warnings import warn
+
+import lightgbm as lgb
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics and env.iteration % self.log_every_n_epochs == 0:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
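
A similar sketch for the LightGBM variant, which hooks in as an lgb.train callback; the ExperimentTracking instance, signature, and training data are placeholders.

# Hypothetical usage of SnowflakeLightgbmCallback (exp, sig, X/y are placeholders).
import lightgbm as lgb
from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback

exp, sig = ..., ...  # ExperimentTracking instance and ModelSignature
train_set = lgb.Dataset(X_train, label=y_train)
valid_set = lgb.Dataset(X_valid, label=y_valid, reference=train_set)
booster = lgb.train(
    {"objective": "regression", "metric": "l2"},
    train_set,
    valid_sets=[valid_set],
    valid_names=["valid"],  # metrics are logged under keys like "valid:l2"
    callbacks=[SnowflakeLightgbmCallback(experiment_tracking=exp, model_signature=sig)],
)
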
@@ -0,0 +1,67 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import xgboost as xgb
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        log_every_n_epochs: int = 1,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        if log_every_n_epochs < 1:
+            raise ValueError("`log_every_n_epochs` must be positive.")
+        self.log_every_n_epochs = log_every_n_epochs
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics and epoch % self.log_every_n_epochs == 0:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
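
And for XGBoost, where the callback plugs into xgb.train and metrics come from the evals list; again, the ExperimentTracking instance, signature, and data are placeholders.

# Hypothetical usage of SnowflakeXgboostCallback (exp, sig, X/y are placeholders).
import xgboost as xgb
from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback

exp, sig = ..., ...  # ExperimentTracking instance and ModelSignature
dtrain = xgb.DMatrix(X_train, label=y_train)
dvalid = xgb.DMatrix(X_valid, label=y_valid)
booster = xgb.train(
    {"objective": "reg:squarederror"},
    dtrain,
    num_boost_round=50,
    evals=[(dvalid, "valid")],  # metrics are logged under keys like "valid:rmse"
    callbacks=[SnowflakeXgboostCallback(experiment_tracking=exp, model_signature=sig)],
)
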
@@ -0,0 +1,14 @@
+from typing import Any, Union
+
+
+def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+    flat_params = {}
+    items = params.items() if isinstance(params, dict) else enumerate(params)
+    for key, value in items:
+        key = str(key).replace(".", "_")  # Replace dots in keys to avoid collisions involving nested keys
+        new_prefix = f"{prefix}.{key}" if prefix else key
+        if isinstance(value, (dict, list)):
+            flat_params.update(flatten_nested_params(value, new_prefix))
+        else:
+            flat_params[new_prefix] = value
+    return flat_params
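
The helper flattens nested dicts and lists into dot-joined keys; a quick worked example (standalone illustration):

flatten_nested_params({"learner": {"eta": 0.3, "eval": ["rmse", "mae"]}})
# -> {"learner.eta": 0.3, "learner.eval.0": "rmse", "learner.eval.1": "mae"}
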
File without changes
@@ -28,7 +28,7 @@ OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.6.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -98,3 +98,6 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
 }
+
+# runtime version environment variable
+ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
@@ -1,4 +1,5 @@
 import functools
+import importlib
 import inspect
 import io
 import itertools
@@ -7,6 +8,7 @@ import logging
 import pickle
 import sys
 import textwrap
+from importlib.abc import Traversable
 from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
@@ -63,6 +65,13 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     ##### Set up Python environment #####
     export PYTHONPATH=/opt/env/site-packages/
+    MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
+
+    if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
+        echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
+        pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+    fi
+
     MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
     if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
         # TODO: Prevent collisions with MLRS packages using virtualenvs
@@ -255,11 +264,24 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
     # Manually traverse the directory and upload each file, since Snowflake PUT
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
-
-
-
-
-
+    upload_path_patterns = set()
+    for p in source_path.resolve().rglob("*"):
+        if p.is_dir():
+            continue
+        if p.name.startswith("."):
+            # Hidden files: use .* pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p.parent.joinpath(".*"))
+        else:
+            # Regular files: use * pattern for batch upload
+            if p.suffix:
+                upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}"))
+            else:
+                upload_path_patterns.add(p)
+
+    for path in upload_path_patterns:
         session.file.put(
             str(path),
             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
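
The new loop batches PUT calls by directory and extension; a standalone sketch of the grouping it produces (paths are made up, and this is not the snowflake-ml code itself):

from pathlib import Path

files = [Path("app/train.py"), Path("app/util.py"), Path("app/data.csv"), Path("app/.env"), Path("app/Makefile")]
patterns = set()
for p in files:
    if p.name.startswith("."):
        patterns.add(p.parent / (f".*{p.suffix}" if p.suffix else ".*"))
    elif p.suffix:
        patterns.add(p.parent / f"*{p.suffix}")
    else:
        patterns.add(p)
print(sorted(str(x) for x in patterns))
# ['app/*.csv', 'app/*.py', 'app/.*', 'app/Makefile']
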
@@ -275,6 +297,27 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
         )
 
 
+def upload_system_resources(session: snowpark.Session, stage_path: PurePath) -> None:
+    resource_ref = importlib.resources.files(__package__).joinpath("scripts")
+
+    def upload_dir(ref: Traversable, relative_path: str = "") -> None:
+        for item in ref.iterdir():
+            current_path = Path(relative_path) / item.name if relative_path else Path(item.name)
+            if item.is_dir():
+                # Recursively process subdirectories
+                upload_dir(item, str(current_path))
+            elif item.is_file():
+                content = item.read_bytes()
+                session.file.put_stream(
+                    io.BytesIO(content),
+                    stage_path.joinpath(current_path).as_posix(),
+                    auto_compress=False,
+                    overwrite=True,
+                )
+
+    upload_dir(resource_ref)
+
+
 def resolve_source(
     source: Union[types.PayloadPath, Callable[..., Any]]
 ) -> Union[types.PayloadPath, Callable[..., Any]]:
@@ -454,8 +497,6 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
-            if not any(r.startswith("cloudpickle") for r in pip_requirements):
-                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
 
         elif isinstance(source, stage_utils.StagePath):
             # copy payload to stage
@@ -470,19 +511,20 @@ class JobPayload:
 
         upload_payloads(session, app_stage_path, *additional_payload_specs)
 
-
-
+        if not any(r.startswith("cloudpickle") for r in pip_requirements):
+            pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+
+        # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=
+                stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
 
-        # Upload startup script to system/ directory within payload
-        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
@@ -491,15 +533,7 @@ class JobPayload:
             overwrite=False,  # FIXME
         )
 
-
-        for script_file in scripts_dir.glob("*"):
-            if script_file.is_file():
-                session.file.put(
-                    script_file.as_posix(),
-                    system_stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+        upload_system_resources(session, system_stage_path)
         python_entrypoint: list[Union[str, PurePath]] = [
             PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
             PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
@@ -4,6 +4,7 @@ from snowflake import snowpark
 from snowflake.snowpark import Row
 from snowflake.snowpark._internal import utils
 from snowflake.snowpark._internal.analyzer import snowflake_plan
+from snowflake.snowpark._internal.utils import is_in_stored_procedure
 
 
 def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
@@ -14,7 +15,10 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
 def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
-
+    kwargs: dict[str, Any] = {"query": query_text, "params": params}
+    if not is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        kwargs["_force_qmark_paramstyle"] = True
+    result = session._conn.run_query(**kwargs)
     if not isinstance(result, dict) or "data" not in result:
         raise ValueError(f"Unprocessable result: {result}")
     return result_set_to_rows(session, result)
@@ -0,0 +1,63 @@
+from typing import Any, Optional, Union
+
+from packaging.version import Version
+from pydantic import BaseModel, Field, RootModel, field_validator
+
+
+class SpcsContainerRuntime(BaseModel):
+    python_version: Version = Field(alias="pythonVersion")
+    hardware_type: str = Field(alias="hardwareType")
+    runtime_container_image: str = Field(alias="runtimeContainerImage")
+
+    @field_validator("python_version", mode="before")
+    @classmethod
+    def validate_python_version(cls, v: Union[str, Version]) -> Version:
+        if isinstance(v, Version):
+            return v
+        try:
+            return Version(v)
+        except Exception:
+            raise ValueError(f"Invalid Python version format: {v}")
+
+    class Config:
+        frozen = True
+        extra = "allow"
+        arbitrary_types_allowed = True
+
+
+class RuntimeEnvironmentEntry(BaseModel):
+    spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
+
+    class Config:
+        extra = "allow"
+        frozen = True
+
+
+class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
+    @field_validator("root", mode="before")
+    @classmethod
+    def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
+        """
+        Pre-validation hook: keep only those items at the root level
+        whose values are dicts. Non-dict values will be dropped.
+
+        Args:
+            data: The input data to filter, expected to be a dictionary.
+
+        Returns:
+            A dictionary containing only the key-value pairs where values are dictionaries.
+
+        Raises:
+            ValueError: If input data is not a dictionary.
+        """
+        # If the entire root is not a dict, raise error immediately
+        if not isinstance(data, dict):
+            raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
+
+        # Filter out any key whose value is not a dict
+        return {key: value for key, value in data.items() if isinstance(value, dict)}
+
+    def get_spcs_container_runtimes(self) -> list[SpcsContainerRuntime]:
+        return [
+            entry.spcs_container_runtime for entry in self.root.values() if entry.spcs_container_runtime is not None
+        ]
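
The root model tolerates non-dict entries at the top level and only surfaces entries carrying a spcsContainerRuntime block; a hedged example with a made-up payload, shaped after the field aliases above:

# Hypothetical payload; field names are the aliases declared above, the values are invented.
sample = (
    '{"runtime_a": {"spcsContainerRuntime": {"pythonVersion": "3.10", "hardwareType": "CPU",'
    ' "runtimeContainerImage": "repo/runtime:1.0"}}, "schemaVersion": "1"}'
)
envs = RuntimeEnvironmentsDict.model_validate_json(sample)  # "schemaVersion" is dropped by the pre-validator
for rt in envs.get_spcs_container_runtimes():
    print(rt.hardware_type, rt.python_version, rt.runtime_container_image)
# CPU 3.10 repo/runtime:1.0
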
@@ -47,8 +47,8 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     if not result:
         return None
 
-    # Sort by start_time first, then by instance_id
-    sorted_instances = sorted(result, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    # Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end.
+    sorted_instances = sorted(result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"])))
     head_instance = sorted_instances[0]
     if not head_instance["instance_id"] or not head_instance["ip_address"]:
         return None
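
The extra leading boolean in the sort key is what pushes rows without a start_time to the back; a tiny worked example with plain dicts standing in for the query rows (real rows carry datetime values):

rows = [
    {"instance_id": "1", "start_time": ""},
    {"instance_id": "0", "start_time": "2024-01-01 00:00:02"},
    {"instance_id": "2", "start_time": "2024-01-01 00:00:01"},
]
ordered = sorted(rows, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"])))
print([r["instance_id"] for r in ordered])  # ['2', '0', '1'] -- the empty start_time sorts last
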
@@ -173,10 +173,10 @@ def wait_for_instances(
 
     start_time = time.time()
     current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
-    logger.
+    logger.info(
         "Waiting for instances to be ready "
-        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
-            min_instances, target_instances, timeout, check_interval
+        "(min_instances={}, target_instances={}, min_wait_time={}s, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, min_wait_time, timeout, check_interval
         )
     )
 
@@ -191,7 +191,7 @@ def wait_for_instances(
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return
 
-        logger.
+        logger.info(
             f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
             f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
         )
@@ -199,7 +199,7 @@ def wait_for_instances(
         current_interval = min(current_interval * 2, check_interval)  # Exponential backoff
 
     raise TimeoutError(
-        f"Timed out after {
+        f"Timed out after {elapsed}s waiting for {min_instances} instances, only " f"{total_nodes} available"
     )
 
 
@@ -1,12 +1,14 @@
 import logging
 import os
+import sys
 from math import ceil
 from pathlib import PurePath
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
 from snowflake.ml.jobs._utils import constants, query_helper, types
+from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -28,22 +30,53 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
     )
 
 
+def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
+    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
+    if not rows:
+        return None
+    try:
+        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
+        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
+    except Exception as e:
+        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
+        return None
+
+    selected_runtime = next(
+        (
+            runtime
+            for runtime in spcs_container_runtimes
+            if (
+                runtime.hardware_type.lower() == target_hardware.lower()
+                and runtime.python_version.major == sys.version_info.major
+                and runtime.python_version.minor == sys.version_info.minor
+            )
+        ),
+        None,
+    )
+    return selected_runtime.runtime_container_image if selected_runtime else None
+
+
 def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
-
-
-
+    hardware = "GPU" if resources.gpu > 0 else "CPU"
+    container_image = None
+    if os.environ.get(constants.ENABLE_IMAGE_VERSION_ENV_VAR, "").lower() == "true":
+        container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]
+
+    if not container_image:
+        image_repo = constants.DEFAULT_IMAGE_REPO
+        image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
+        image_tag = _get_runtime_image_tag()
+        container_image = f"{image_repo}/{image_name}:{image_tag}"
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
-        repo=image_repo,
-        image_name=image_name,
-        image_tag=image_tag,
         resource_requests=resources,
         resource_limits=resources,
+        container_image=container_image,
     )
 
 
@@ -220,7 +253,7 @@ def generate_service_spec(
         "containers": [
             {
                 "name": constants.DEFAULT_CONTAINER_NAME,
-                "image": image_spec.
+                "image": image_spec.container_image,
                 "command": ["/usr/local/bin/_entrypoint.sh"],
                 "args": [
                     (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint