snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +6 -4
- snowflake/ml/_internal/utils/service_logger.py +118 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
- snowflake/ml/data/data_connector.py +4 -34
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/dataset/dataset_reader.py +2 -8
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +55 -0
- snowflake/ml/experiment/callback/xgboost.py +63 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +159 -52
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +22 -6
- snowflake/ml/jobs/manager.py +5 -3
- snowflake/ml/model/_client/model/model_version_impl.py +56 -48
- snowflake/ml/model/_client/ops/service_ops.py +194 -14
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/models/huggingface_pipeline.py +71 -49
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +30 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py (+159 -52)

@@ -2,6 +2,8 @@ import functools
 import inspect
 import io
 import itertools
+import keyword
+import logging
 import pickle
 import sys
 import textwrap
@@ -22,8 +24,11 @@ from snowflake.ml.jobs._utils import (
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation

+logger = logging.getLogger(__name__)
+
 cp.register_pickle_by_value(function_payload_utils)

+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -32,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
     #!/bin/bash

+    ##### Get system scripts directory #####
+    SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
     ##### Perform common set up steps #####
     set -e # exit if a command fails

@@ -55,6 +63,13 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

     ##### Set up Python environment #####
     export PYTHONPATH=/opt/env/site-packages/
+    MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
+
+    if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
+        echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
+        pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+    fi
+
     MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
     if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
         # TODO: Prevent collisions with MLRS packages using virtualenvs
@@ -75,12 +90,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

     # Check if the local get_instance_ip.py script exists
     HELPER_EXISTS=$(
-        [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+        [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
     )

+
     # Configure IP address and logging directory
     if [ "$HELPER_EXISTS" = "true" ]; then
-        eth0Ip=$(python3 get_instance_ip.py
+        eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+            "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
     else
         eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
     fi
@@ -103,7 +120,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

     # Determine if it should be a worker or a head node for batch jobs
     if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-        head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+        head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
         if [ $? -eq 0 ]; then
             # Parse the output using read
             read head_index head_ip head_status<<< "$head_info"
@@ -185,7 +202,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

         # Start the worker shutdown listener in the background
         echo "Starting worker shutdown listener..."
-        python worker_shutdown_listener.py
+        python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
         WORKER_EXIT_CODE=$?

         echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
@@ -218,19 +235,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

         # After the user's job completes, signal workers to shut down
         echo "User job completed. Signaling workers to shut down..."
-        python signal_workers.py --wait-time 15
+        python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
         echo "Head node job completed. Exiting."
     fi
     """
 ).strip()


+def resolve_path(path: str) -> types.PayloadPath:
+    try:
+        stage_path = stage_utils.StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
+
+
+def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+    for source_path, remote_relative_path in payload_specs:
+        payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+        if isinstance(source_path, stage_utils.StagePath):
+            # only copy files into one stage directory from another stage directory, not from stage file
+            # due to incomplete of StagePath functionality
+            session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+        elif isinstance(source_path, Path):
+            if source_path.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source_path.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
+                session.file.put(
+                    str(source_path.resolve()),
+                    payload_stage_path.as_posix(),
+                    overwrite=True,
+                    auto_compress=False,
+                )
+
+
 def resolve_source(
-    source: Union[
-) -> Union[
+    source: Union[types.PayloadPath, Callable[..., Any]]
+) -> Union[types.PayloadPath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source,
+    elif isinstance(source, types.PayloadPath):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
@@ -239,8 +296,8 @@ def resolve_source(


 def resolve_entrypoint(
-    source: Union[
-    entrypoint: Optional[
+    source: Union[types.PayloadPath, Callable[..., Any]],
+    entrypoint: Optional[types.PayloadPath],
 ) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
@@ -289,6 +346,73 @@ def resolve_entrypoint(
     )


+def resolve_additional_payloads(
+    additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+) -> list[types.PayloadSpec]:
+    """
+    Determine how to stage local packages so that imports continue to work.
+
+    Args:
+        additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+            import path
+            e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
+            if there is no import path, the last part of path will be considered as import path
+            e.g. the import path of "proj/src/helper" is "helper"
+
+    Returns:
+        A list of payloadSpec for additional payloads.
+
+    Raises:
+        FileNotFoundError: If any specified package path does not exist.
+        ValueError: If the format of local_packages is invalid.
+
+    """
+    if not additional_payloads:
+        return []
+
+    logger.warning(
+        "When providing a stage path as an additional payload, "
+        "please ensure it points to a directory. "
+        "Files are not currently supported."
+    )
+
+    additional_payloads_paths = []
+    for pkg in additional_payloads:
+        if isinstance(pkg, str):
+            source_path = resolve_path(pkg).absolute()
+            module_path = source_path.name
+        elif isinstance(pkg, tuple):
+            try:
+                source_path_str, module_path = pkg
+            except ValueError:
+                raise ValueError(
+                    f"Invalid format in `additional_payloads`. "
+                    f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                )
+            source_path = resolve_path(source_path_str).absolute()
+        else:
+            raise ValueError("the format of additional payload is not correct")
+
+        if not source_path.exists():
+            raise FileNotFoundError(f"{source_path} does not exist")
+
+        if isinstance(source_path, Path):
+            if source_path.is_file():
+                raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+        module_parts = module_path.split(".")
+        for part in module_parts:
+            if not part.isidentifier() or keyword.iskeyword(part):
+                raise ValueError(
+                    f"Invalid module import path '{module_path}'. "
+                    f"'{part}' is not a valid Python identifier or is a keyword."
+                )
+
+        dest_path = PurePath(*module_parts)
+        additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+    return additional_payloads_paths
+
+
 class JobPayload:
     def __init__(
         self,
@@ -296,11 +420,13 @@ class JobPayload:
         entrypoint: Optional[Union[str, Path]] = None,
         *,
         pip_requirements: Optional[list[str]] = None,
+        additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
     ) -> None:
         # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-        self.source =
-        self.entrypoint =
+        self.source = resolve_path(source) if isinstance(source, str) else source
+        self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
+        self.additional_payloads = additional_payloads

     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
         # Prepare local variables
@@ -308,6 +434,7 @@ class JobPayload:
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
         pip_requirements = self.pip_requirements or []
+        additional_payload_specs = resolve_additional_payloads(self.additional_payloads)

         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
@@ -323,85 +450,65 @@ class JobPayload:
             params=[stage_name],
         )

-        # Upload payload to stage
-
+        # Upload payload to stage - organize into app/ subdirectory
+        app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+        if not isinstance(source, types.PayloadPath):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
-                stage_location=
+                stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
-            if not any(r.startswith("cloudpickle") for r in pip_requirements):
-                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")

         elif isinstance(source, stage_utils.StagePath):
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
-
-            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

         elif isinstance(source, Path):
-
-
-            # can't handle directories. Reduce the number of PUT operations by using
-            # wildcard patterns to batch upload files with the same extension.
-            for path in {
-                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                for p in source.resolve().rglob("*")
-                if p.is_file()
-            }:
-                session.file.put(
-                    str(path),
-                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
-            else:
-                session.file.put(
-                    str(source.resolve()),
-                    stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+            if source.is_file():
                 source = source.parent

-
-
+        upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+        if not any(r.startswith("cloudpickle") for r in pip_requirements):
+            pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+
+        # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=
+                stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )

-        # Upload startup script
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-            stage_location=
+            stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
             auto_compress=False,
             overwrite=False, # FIXME
         )

-        # Upload system scripts
         scripts_dir = Path(__file__).parent.joinpath("scripts")
         for script_file in scripts_dir.glob("*"):
             if script_file.is_file():
                 session.file.put(
                     script_file.as_posix(),
-
+                    system_stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-
         python_entrypoint: list[Union[str, PurePath]] = [
-            PurePath("mljob_launcher.py"),
-            entrypoint.file_path.relative_to(source),
+            PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+            PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
         ]
         if entrypoint.main_func:
             python_entrypoint += ["--script_main_func", entrypoint.main_func]
@@ -410,7 +517,7 @@ class JobPayload:
             stage_path=stage_path,
             entrypoint=[
                 "bash",
-                _STARTUP_SCRIPT_PATH,
+                f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                 *python_entrypoint,
             ],
         )
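Note: the new `additional_payloads` argument on `JobPayload.__init__` stages extra local packages next to the entrypoint so their imports keep resolving inside the job container. Below is a minimal sketch of the destination mapping performed by `resolve_additional_payloads` above; the project paths are hypothetical and no Snowflake session is involved in this illustration:

```python
import keyword
from pathlib import PurePath
from typing import Union

# Hypothetical inputs mirroring the docstring example: an explicit
# (source_path, import_path) pair plus a bare path whose last component
# becomes the import path.
payloads: list[Union[str, tuple[str, str]]] = [("proj/src/utils", "src.utils"), "proj/src/helper"]

for pkg in payloads:
    source, import_path = pkg if isinstance(pkg, tuple) else (pkg, PurePath(pkg).name)
    parts = import_path.split(".")
    # Same validation the new code applies: every dotted segment must be a valid,
    # non-keyword Python identifier.
    assert all(p.isidentifier() and not keyword.iskeyword(p) for p in parts)
    print(source, "->", PurePath(*parts))

# proj/src/utils -> src/utils
# proj/src/helper -> helper
```

At the API level this corresponds to `JobPayload(source, entrypoint=..., additional_payloads=[("proj/src/utils", "src.utils"), "proj/src/helper"])`; each package is uploaded under the app subdirectory of the job stage at its dotted-path location.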
snowflake/ml/jobs/_utils/scripts/constants.py (+0 -22)

@@ -1,26 +1,4 @@
-from snowflake.ml.jobs._utils import constants as mljob_constants
-
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
-# The followings are Inherited from snowflake.ml.jobs._utils.constants
-# We need to copy them here since snowml package on the server side does
-# not have the latest version of the code
-
-# Log start and end messages
-LOG_START_MSG = getattr(
-    mljob_constants,
-    "LOG_START_MSG",
-    "--------------------------------\nML job started\n--------------------------------",
-)
-LOG_END_MSG = getattr(
-    mljob_constants,
-    "LOG_END_MSG",
-    "--------------------------------\nML job finished\n--------------------------------",
-)
-
-# min_instances environment variable name
-MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
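The constants removed here are not dropped; they reappear in mljob_launcher.py (next diff) using the same `getattr`-with-default pattern, which keeps the launcher usable when the server-side snowml package is older than the client. A one-line sketch of that pattern:

```python
from snowflake.ml.jobs._utils import constants

# Use the attribute if the installed snowml version defines it, otherwise fall back
# to a hard-coded default (the value shown mirrors the new mljob_launcher.py code).
MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
```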
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py (+126 -23)

@@ -3,6 +3,7 @@ import copy
 import importlib.util
 import json
 import logging
+import math
 import os
 import runpy
 import sys
@@ -13,7 +14,6 @@ from pathlib import Path
 from typing import Any, Optional

 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

 from snowflake.ml.jobs._utils import constants
 from snowflake.snowpark import Session
@@ -27,13 +27,35 @@ except ImportError:
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

+
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+# Log start and end messages
+LOG_START_MSG = getattr(
+    constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")

-# Constants for the
-
-TIMEOUT = 720 # seconds
+# Constants for the wait_for_instances function
+MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
+TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
+CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds


 try:
@@ -76,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"


-def
+def wait_for_instances(
+    min_instances: int,
+    target_instances: int,
+    *,
+    min_wait_time: float = -1,  # seconds
+    timeout: float = 720,  # seconds
+    check_interval: float = 10,  # seconds
+) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.

     Args:
         min_instances: Minimum number of instances required
+        target_instances: Target number of instances to wait for
+        min_wait_time: Minimum time to wait for target_instances to be available.
+            If less than 0, automatically set based on target_instances.
+        timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+        check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+    Examples:
+        Scenario 1 - Ideal case (target met quickly):
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+            If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+        Scenario 2 - Min instances met, target not reached:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+            If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+        Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+            If 2 instances are available after 5 seconds, function continues waiting for target_instances
+            until either 4 instances are found or 30 seconds have elapsed.
+
+        Scenario 4 - Timeout scenario:
+            wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+            If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+        Scenario 5 - Single instance job (early return):
+            wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+            The function returns without waiting because target_instances <= 1.

     Raises:
+        ValueError: If arguments are invalid
         TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
     """
-    if min_instances
-
+    if min_instances > target_instances:
+        raise ValueError(
+            f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+        )
+    if timeout < 0:
+        raise ValueError("Timeout must be greater than 0")
+    if check_interval < 0:
+        raise ValueError("Check interval must be greater than 0")
+
+    if target_instances <= 1:
+        logger.debug("Target instances is 1 or less, no need to wait for additional instances")
         return

+    if min_wait_time < 0:
+        # Automatically set min_wait_time based on the number of target instances
+        # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+        #   target_instances = 1   => min_wait_time = 0
+        #   target_instances = 2   => min_wait_time = 3
+        #   target_instances = 4   => min_wait_time = 6
+        #   target_instances = 8   => min_wait_time = 9
+        #   target_instances = 32  => min_wait_time = 15
+        #   target_instances = 50  => min_wait_time = 16.9
+        #   target_instances = 100 => min_wait_time = 19.9
+        min_wait_time = min(3 * math.log2(target_instances), timeout / 10)  # Clamp to timeout / 10
+
     # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
     from common_utils import common_util as mlrs_util

     start_time = time.time()
-
-
-
+    current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
+    logger.debug(
+        "Waiting for instances to be ready "
+        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, timeout, check_interval
+        )
+    )

-    while time.time() - start_time < timeout:
+    while (elapsed := time.time() - start_time) < timeout:
         total_nodes = mlrs_util.get_num_ray_nodes()
-
-
-
+        if total_nodes >= target_instances:
+            # Best case scenario: target_instances are already available
+            logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+        elif total_nodes >= min_instances and elapsed >= min_wait_time:
+            # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return

-        logger.
-            f"Waiting for instances: {total_nodes}
-            f"
+        logger.info(
+            f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+            f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
         )
-        time.sleep(
+        time.sleep(current_interval)
+        current_interval = min(current_interval * 2, check_interval)  # Exponential backoff

     raise TimeoutError(
-        f"Timed out after {
-        f"{mlrs_util.get_num_ray_nodes()} available"
+        f"Timed out after {elapsed}s waiting for {min_instances} instances, only " f"{total_nodes} available"
     )

@@ -137,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     original_argv = sys.argv
     sys.argv = [script_path, *script_args]

+    # Ensure payload directory is in sys.path for module imports
+    # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+    # but user scripts are in the payload directory and may import from each other
+    payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+    if payload_dir and payload_dir not in sys.path:
+        sys.path.insert(0, payload_dir)
+
     # Create a Snowpark session before running the script
     # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
     session = Session.builder.configs(SnowflakeLoginOptions()).create() # noqa: F841
@@ -183,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
+    # Ensure the output directory exists before trying to write result files.
+    output_dir = os.path.dirname(JOB_RESULT_PATH)
+    os.makedirs(output_dir, exist_ok=True)
+
     try:
         # Wait for minimum required instances if specified
         min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-
-
+        target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+        if target_instances_str and int(target_instances_str) > 1:
+            wait_for_instances(
+                int(min_instances_str),
+                int(target_instances_str),
+                min_wait_time=MIN_WAIT_TIME,
+                timeout=TIMEOUT,
+                check_interval=CHECK_INTERVAL,
+            )

         # Log start marker for user script execution
         print(LOG_START_MSG) # noqa: T201
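The reworked `wait_for_instances` combines a minimum/target instance count with an adaptive grace period (`3 * log2(target_instances)`, clamped to `timeout / 10`) and an exponentially backed-off poll interval. Below is a standalone sketch of those two timing rules using the defaults from the diff (`timeout=720`, `check_interval=10`); the helper name `auto_min_wait` is ours, not part of the package:

```python
import math

def auto_min_wait(target_instances: int, timeout: float = 720.0) -> float:
    # min_wait_time = 3 * log2(target_instances), clamped to timeout / 10
    return min(3 * math.log2(target_instances), timeout / 10)

for n in (2, 4, 8, 32, 100):
    print(n, round(auto_min_wait(n), 1))
# 2 -> 3.0, 4 -> 6.0, 8 -> 9.0, 32 -> 15.0, 100 -> 19.9

# Poll interval: starts at 1s and doubles until it reaches check_interval.
interval, check_interval, schedule = 1.0, 10.0, []
for _ in range(6):
    schedule.append(interval)
    interval = min(interval * 2, check_interval)
print(schedule)  # [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]
```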
snowflake/ml/jobs/_utils/spec_utils.py (+1 -1)

@@ -181,7 +181,7 @@ def generate_service_spec(
     # TODO: Add hooks for endpoints for integration with TensorBoard etc

     env_vars = {
-        constants.PAYLOAD_DIR_ENV_VAR:
+        constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
     }
     endpoints: list[dict[str, Any]] = []
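Taken together with the launcher change above, this establishes the contract for the new stage layout: the service spec exports the payload directory via `constants.PAYLOAD_DIR_ENV_VAR` (now pointing at `constants.APP_MOUNT_PATH`), and `run_script` prepends that directory to `sys.path` so user modules can still import one another even though mljob_launcher.py now lives in the system mount. A small sketch of that handshake; the environment variable name matches the launcher's fallback, while the mount path below is only an assumed example value:

```python
import os
import sys

# Assumed example value; the real path comes from constants.APP_MOUNT_PATH in the service spec.
os.environ.setdefault("MLRS_PAYLOAD_DIR", "/mnt/job_stage/app")

# Mirrors the sys.path handling added to run_script() in mljob_launcher.py.
payload_dir = os.environ.get("MLRS_PAYLOAD_DIR")
if payload_dir and payload_dir not in sys.path:
    sys.path.insert(0, payload_dir)

print(sys.path[0])  # /mnt/job_stage/app
```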