snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +42 -16
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +12 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +95 -39
- snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
- snowflake/ml/jobs/_utils/spec_utils.py +30 -6
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +10 -7
- snowflake/ml/jobs/job.py +176 -28
- snowflake/ml/jobs/manager.py +119 -26
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +24 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +73 -28
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +9 -3
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -12,10 +12,17 @@ import cloudpickle as cp
 from packaging import version

 from snowflake import snowpark
-from snowflake.ml.jobs._utils import
+from snowflake.ml.jobs._utils import (
+    constants,
+    function_payload_utils,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation

+cp.register_pickle_by_value(function_payload_utils)
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
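The `cp.register_pickle_by_value(function_payload_utils)` call makes cloudpickle serialize the new helper module by value, so the `FunctionPayload` wrapper defined there travels inside the pickled job payload instead of requiring the module to be importable in the job container. A minimal sketch of the effect, assuming only the public cloudpickle API and the `FunctionPayload` class referenced elsewhere in this diff:

```python
import cloudpickle as cp

from snowflake.ml.jobs._utils import function_payload_utils

# Without this registration, pickling would store only a module/qualname reference;
# with it, the class and function definitions are embedded in the payload bytes.
cp.register_pickle_by_value(function_payload_utils)
blob = cp.dumps(function_payload_utils.FunctionPayload)
```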
@@ -100,6 +107,11 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 # Parse the output using read
 read head_index head_ip head_status<<< "$head_info"

+if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+    NODE_TYPE="worker"
+    echo "{constants.LOG_START_MSG}"
+fi
+
 # Use the parsed variables
 echo "Head Instance Index: $head_index"
 echo "Head Instance IP: $head_ip"
@@ -117,9 +129,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     exit 1
 fi

-
-    NODE_TYPE="worker"
-fi
+
 fi

 # Common parameters for both head and worker nodes
@@ -168,6 +178,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     # Start Ray on a worker node - run in background
     ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &

+    echo "Worker node started on address $eth0Ip. See more logs in the head node."
+
+    echo "{constants.LOG_END_MSG}"
+
     # Start the worker shutdown listener in the background
     echo "Starting worker shutdown listener..."
     python worker_shutdown_listener.py
@@ -189,15 +203,16 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

     # Start Ray on the head node
     ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+
     ##### End Ray configuration #####

     # TODO: Monitor MLRS and handle process crashes
     python -m web.ml_runtime_grpc_server &

     # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+    echo Running command: python "$@"

     # Run user's Python entrypoint
-    echo Running command: python "$@"
     python "$@"

     # After the user's job completes, signal workers to shut down
@@ -209,20 +224,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()


-def resolve_source(
+def resolve_source(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
+) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source, Path):
-        # Validate source
-        source = source
+    elif isinstance(source, (Path, stage_utils.StagePath)):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
     else:
-        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+        raise ValueError("Unsupported source type. Source must be a stage, file, directory, or callable.")


-def resolve_entrypoint(
+def resolve_entrypoint(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
+    entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
         return types.PayloadEntrypoint(
@@ -237,11 +255,11 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             # Infer entrypoint from source
             entrypoint = parent
         else:
-            raise ValueError("
+            raise ValueError("Entrypoint must be provided when source is a directory")
     elif entrypoint.is_absolute():
         # Absolute path - validate it's a subpath of source dir
         if not entrypoint.is_relative_to(parent):
-            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
     else:
         # Relative path
         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
@@ -257,6 +275,7 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
             f" the source directory (source={parent}, entrypoint={entrypoint})"
         )
+
     if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
         raise ValueError(
             "Unsupported entrypoint type:"
@@ -277,8 +296,9 @@ class JobPayload:
         *,
         pip_requirements: Optional[list[str]] = None,
     ) -> None:
-
-        self.
+        # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
+        self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
+        self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements

     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -302,7 +322,7 @@ class JobPayload:
         ).collect()

         # Upload payload to stage
-        if not isinstance(source, Path):
+        if not isinstance(source, (Path, stage_utils.StagePath)):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -313,27 +333,38 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
             if not any(r.startswith("cloudpickle") for r in pip_requirements):
                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-
-
-        #
-
-
-
-        }
+
+        elif isinstance(source, stage_utils.StagePath):
+            # copy payload to stage
+            if source == entrypoint.file_path:
+                source = source.parent
+            source_path = source.as_posix() + "/"
+            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+
+        elif isinstance(source, Path):
+            if source.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
                 session.file.put(
-                    str(
-                    stage_path.
+                    str(source.resolve()),
+                    stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-
-        session.file.put(
-            str(source.resolve()),
-            stage_path.as_posix(),
-            overwrite=True,
-            auto_compress=False,
-        )
-        source = source.parent
+                source = source.parent

         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
@@ -494,9 +525,15 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
     return param_code


-def generate_python_code(
+def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
     """Generate an entrypoint script from a Python function."""
-
+
+    if isinstance(payload, function_payload_utils.FunctionPayload):
+        function = payload.function
+    else:
+        function = payload
+
+    signature = inspect.signature(function)
     if any(
         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
         for p in signature.parameters.values()
@@ -505,21 +542,20 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F

     # Mirrored from Snowpark generate_python_code() function
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
-    source_code_comment = _generate_source_code_comment(
+    source_code_comment = _generate_source_code_comment(function) if source_code_display else ""

     arg_dict_name = "kwargs"
-    if
+    if isinstance(payload, function_payload_utils.FunctionPayload):
         param_code = f"{arg_dict_name} = {{}}"
     else:
         param_code = _generate_param_handler_code(signature, arg_dict_name)
-
     return f"""
 import sys
 import pickle

 try:
     {textwrap.indent(source_code_comment, ' ')}
-    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
 except (TypeError, pickle.PickleError):
     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
         raise RuntimeError(
@@ -543,3 +579,23 @@ if __name__ == '__main__':

     __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
+
+
+def create_function_payload(
+    func: Callable[..., Any], *args: Any, **kwargs: Any
+) -> function_payload_utils.FunctionPayload:
+    signature = inspect.signature(func)
+    bound = signature.bind(*args, **kwargs)
+    bound.apply_defaults()
+    session_argument = ""
+    session = None
+    for name, val in list(bound.arguments.items()):
+        if isinstance(val, snowpark.Session):
+            if session:
+                raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
+            session = val
+            session_argument = name
+            del bound.arguments[name]
+    payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
+
+    return payload
snowflake/ml/jobs/_utils/scripts/constants.py
CHANGED
@@ -1,4 +1,26 @@
+from snowflake.ml.jobs._utils import constants as mljob_constants
+
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
+
+
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+
+# Log start and end messages
+LOG_START_MSG = getattr(
+    mljob_constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    mljob_constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED
@@ -2,25 +2,35 @@ import argparse
 import copy
 import importlib.util
 import json
+import logging
 import os
 import runpy
 import sys
+import time
 import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Optional

 import cloudpickle
+from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session

+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-
 JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")

+# Constants for the wait_for_min_instances function
+CHECK_INTERVAL = 10  # seconds
+TIMEOUT = 720  # seconds
+

 try:
     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -62,6 +72,48 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"


+def wait_for_min_instances(min_instances: int) -> None:
+    """
+    Wait until the specified minimum number of instances are available in the Ray cluster.
+
+    Args:
+        min_instances: Minimum number of instances required
+
+    Raises:
+        TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
+    """
+    if min_instances <= 1:
+        logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+        return
+
+    # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
+    from common_utils import common_util as mlrs_util
+
+    start_time = time.time()
+    timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
+    check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
+    logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+
+    while time.time() - start_time < timeout:
+        total_nodes = mlrs_util.get_num_ray_nodes()
+
+        if total_nodes >= min_instances:
+            elapsed = time.time() - start_time
+            logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+
+        logger.debug(
+            f"Waiting for instances: {total_nodes}/{min_instances} available "
+            f"(elapsed: {time.time() - start_time:.1f}s)"
+        )
+        time.sleep(check_interval)
+
+    raise TimeoutError(
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
+        f"{mlrs_util.get_num_ray_nodes()} available"
+    )
+
+
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
     """
     Execute a Python script and return its result.
@@ -86,6 +138,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

     try:
+
         if main_func:
             # Use importlib for scripts with a main function defined
             module_name = Path(script_path).stem
@@ -126,9 +179,21 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
-    # Run the script with the specified arguments
     try:
+        # Wait for minimum required instances if specified
+        min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
+        if min_instances_str and int(min_instances_str) > 1:
+            wait_for_min_instances(int(min_instances_str))
+
+        # Log start marker for user script execution
+        print(LOG_START_MSG)  # noqa: T201
+
+        # Run the script with the specified arguments
         result = run_script(script_path, *script_args, main_func=script_main_func)
+
+        # Log end marker for user script execution
+        print(LOG_END_MSG)  # noqa: T201
+
         result_obj = ExecutionResult(result=result)
         return result_obj
     except Exception as e:
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from math import ceil
 from pathlib import PurePath
 from typing import Any, Optional, Union
@@ -30,7 +31,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag =
+    image_tag = _get_runtime_image_tag()

     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -85,7 +86,8 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
 ) -> dict[str, Any]:
     """
@@ -96,13 +98,13 @@ def generate_service_spec(
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
-
+        target_instances: Number of instances for multi-node job
         enable_metrics: Enable platform metrics for the job
+        min_instances: Minimum number of instances required to start the job

     Returns:
         Job service specification
     """
-    is_multi_node = num_instances is not None and num_instances > 1
     image_spec = _get_image_spec(session, compute_pool)

     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
|
|
180
182
|
}
|
181
183
|
endpoints = []
|
182
184
|
|
183
|
-
if
|
185
|
+
if target_instances > 1:
|
184
186
|
# Update environment variables for multi-node job
|
185
187
|
env_vars.update(constants.RAY_PORTS)
|
186
|
-
env_vars[
|
188
|
+
env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
|
189
|
+
env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
|
187
190
|
|
188
191
|
# Define Ray endpoints for intra-service instance communication
|
189
192
|
ray_endpoints = [
|
@@ -344,3 +347,24 @@ def _merge_lists_of_dicts(
             result[key] = d

     return list(result.values())
+
+
+def _get_runtime_image_tag() -> str:
+    """
+    Detect runtime image tag from container environment.
+
+    Checks in order:
+    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
+    2. Falls back to hardcoded default
+
+    Returns:
+        str: The runtime image tag to use for job containers
+    """
+    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
+    if env_tag:
+        logging.debug(f"Using runtime image tag from environment: {env_tag}")
+        return env_tag
+
+    # Fall back to default
+    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
+    return constants.DEFAULT_IMAGE_TAG
snowflake/ml/jobs/_utils/stage_utils.py
ADDED
@@ -0,0 +1,119 @@
+import os
+import re
+from os import PathLike
+from pathlib import Path, PurePath
+from typing import Union
+
+from snowflake.ml._internal.utils import identifier
+
+PROTOCOL_NAME = "snow"
+_SNOWURL_PATH_RE = re.compile(
+    rf"^(?:(?:{PROTOCOL_NAME}://)?"
+    r"(?<!@)(?P<domain>\w+)/"
+    rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+    r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+)
+
+_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+class StagePath:
+    def __init__(self, path: str) -> None:
+        stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+        if not stage_match:
+            raise ValueError(f"{path} is not a valid stage path")
+        path = path.strip()
+        self._raw_path = path
+        relpath = stage_match.group("relpath")
+        start, _ = stage_match.span("relpath")
+        self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+        self._path = Path(relpath or "")
+
+    @property
+    def parent(self) -> "StagePath":
+        if self._path.parent == Path(""):
+            return StagePath(self._root)
+        else:
+            return StagePath(f"{self._root}/{self._path.parent}")
+
+    @property
+    def root(self) -> str:
+        return self._root
+
+    @property
+    def suffix(self) -> str:
+        return self._path.suffix
+
+    def _compose_path(self, path: Path) -> str:
+        # in pathlib, Path("") = "."
+        if path == Path(""):
+            return self.root
+        else:
+            return f"{self.root}/{path}"
+
+    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if stage_path.root == self.root:
+            return self._path.is_relative_to(stage_path._path)
+        else:
+            return False
+
+    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            return self._path.relative_to(stage_path._path)
+        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+    def absolute(self) -> "StagePath":
+        return self
+
+    def as_posix(self) -> str:
+        return self._compose_path(self._path)
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def exists(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_file(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_dir(self) -> bool:
+        return True
+
+    def is_absolute(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return self.as_posix()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StagePath):
+            raise NotImplementedError
+        return bool(self.root == other.root and self._path == other._path)
+
+    def __fspath__(self) -> str:
+        return self._compose_path(self._path)
+
+    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        path = self
+        for arg in args:
+            path = path._make_child(arg)
+        return path
+
+    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            child_path = self._path.joinpath(stage_path._path)
+            return StagePath(self._compose_path(child_path))
+        else:
+            return stage_path
+
+
+def identify_stage_path(path: str) -> Union[StagePath, Path]:
+    try:
+        stage_path = StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -2,18 +2,22 @@ from dataclasses import dataclass
 from pathlib import PurePath
 from typing import Literal, Optional, Union

+from snowflake.ml.jobs._utils import stage_utils
+
 JOB_STATUS = Literal[
     "PENDING",
     "RUNNING",
     "FAILED",
     "DONE",
+    "CANCELLING",
+    "CANCELLED",
     "INTERNAL_ERROR",
 ]


 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: PurePath
+    file_path: Union[PurePath, stage_utils.StagePath]
     main_func: Optional[str]
snowflake/ml/jobs/decorators.py
CHANGED
@@ -7,7 +7,7 @@ from typing_extensions import ParamSpec
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import
+from snowflake.ml.jobs._utils import payload_utils

 _PROJECT = "MLJob"

@@ -24,7 +24,8 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[dict[str, str]] = None,
-
+    target_instances: int = 1,
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -40,7 +41,9 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use for the job.
         schema: The schema to use for the job.
@@ -59,8 +62,7 @@ def remote(

        @functools.wraps(func)
        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload =
-            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+            payload = payload_utils.create_function_payload(func, *args, **kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
@@ -69,11 +71,12 @@ def remote(
                 external_access_integrations=external_access_integrations,
                 query_warehouse=query_warehouse,
                 env_vars=env_vars,
-
+                target_instances=target_instances,
+                min_instances=min_instances,
                 enable_metrics=enable_metrics,
                 database=database,
                 schema=schema,
-                session=session,
+                session=payload.session or session,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job