snowflake-ml-python 1.9.1__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +6 -4
- snowflake/ml/_internal/utils/service_logger.py +101 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
- snowflake/ml/data/data_connector.py +4 -34
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/dataset/dataset_reader.py +2 -8
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +150 -49
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +125 -22
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +22 -6
- snowflake/ml/jobs/manager.py +5 -3
- snowflake/ml/model/_client/ops/service_ops.py +17 -2
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +42 -4
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +30 -28
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
```diff
--- a/snowflake/ml/jobs/_utils/payload_utils.py
+++ b/snowflake/ml/jobs/_utils/payload_utils.py
@@ -2,6 +2,8 @@ import functools
 import inspect
 import io
 import itertools
+import keyword
+import logging
 import pickle
 import sys
 import textwrap
@@ -22,8 +24,11 @@ from snowflake.ml.jobs._utils import (
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
+logger = logging.getLogger(__name__)
+
 cp.register_pickle_by_value(function_payload_utils)
 
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -32,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
     #!/bin/bash
 
+    ##### Get system scripts directory #####
+    SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
     ##### Perform common set up steps #####
     set -e # exit if a command fails
 
@@ -75,12 +83,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Check if the local get_instance_ip.py script exists
     HELPER_EXISTS=$(
-        [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+        [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
     )
 
+
    # Configure IP address and logging directory
     if [ "$HELPER_EXISTS" = "true" ]; then
-        eth0Ip=$(python3 get_instance_ip.py
+        eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+            "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
     else
         eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
     fi
@@ -103,7 +113,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Determine if it should be a worker or a head node for batch jobs
     if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-        head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+        head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
         if [ $? -eq 0 ]; then
             # Parse the output using read
             read head_index head_ip head_status<<< "$head_info"
@@ -185,7 +195,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
         # Start the worker shutdown listener in the background
         echo "Starting worker shutdown listener..."
-        python worker_shutdown_listener.py
+        python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
         WORKER_EXIT_CODE=$?
 
         echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
```
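Note: `_STARTUP_SCRIPT_CODE` is an f-string, so the doubled braces in `${{SYSTEM_DIR}}` are escapes rather than shell syntax. A minimal standalone sketch of how they render:

```python
import textwrap

# Inside an f-string, "{{" and "}}" emit literal braces, so "${{SYSTEM_DIR}}"
# renders as the shell variable reference "${SYSTEM_DIR}" in the generated script.
script = textwrap.dedent(
    f"""
    #!/bin/bash
    SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
    python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
    """
).strip()
print(script)
# #!/bin/bash
# SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
# python "${SYSTEM_DIR}/worker_shutdown_listener.py"
```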
```diff
@@ -218,19 +228,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
         # After the user's job completes, signal workers to shut down
         echo "User job completed. Signaling workers to shut down..."
-        python signal_workers.py --wait-time 15
+        python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
         echo "Head node job completed. Exiting."
     fi
     """
 ).strip()
 
 
+def resolve_path(path: str) -> types.PayloadPath:
+    try:
+        stage_path = stage_utils.StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
+
+
+def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+    for source_path, remote_relative_path in payload_specs:
+        payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+        if isinstance(source_path, stage_utils.StagePath):
+            # only copy files into one stage directory from another stage directory, not from stage file
+            # due to incomplete of StagePath functionality
+            session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+        elif isinstance(source_path, Path):
+            if source_path.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source_path.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
+                session.file.put(
+                    str(source_path.resolve()),
+                    payload_stage_path.as_posix(),
+                    overwrite=True,
+                    auto_compress=False,
+                )
+
+
 def resolve_source(
-    source: Union[
-) -> Union[
+    source: Union[types.PayloadPath, Callable[..., Any]]
+) -> Union[types.PayloadPath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source,
+    elif isinstance(source, types.PayloadPath):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
```
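Note: `resolve_path` leans on `stage_utils.StagePath` raising `ValueError` for non-stage inputs. A self-contained sketch of the dispatch pattern, using a toy stand-in for `StagePath` (the real class's validation rules are an assumption here):

```python
from pathlib import Path


class StagePath:
    # Toy stand-in: the real stage_utils.StagePath accepts stage URIs such as
    # "@stage/..." or "snow://..."; the exact validation is an assumption.
    def __init__(self, path: str) -> None:
        if not (path.startswith("@") or path.startswith("snow://")):
            raise ValueError(f"not a stage path: {path}")
        self.raw = path


def resolve_path(path: str):
    # Same shape as the new helper: try the stage parser, fall back to Path.
    try:
        return StagePath(path)
    except ValueError:
        return Path(path)


print(type(resolve_path("@my_stage/app")).__name__)    # StagePath
print(type(resolve_path("proj/src/helper")).__name__)  # PosixPath / WindowsPath
```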
```diff
@@ -239,8 +289,8 @@ def resolve_source(
 
 
 def resolve_entrypoint(
-    source: Union[
-    entrypoint: Optional[
+    source: Union[types.PayloadPath, Callable[..., Any]],
+    entrypoint: Optional[types.PayloadPath],
 ) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
@@ -289,6 +339,73 @@ def resolve_entrypoint(
     )
 
 
+def resolve_additional_payloads(
+    additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+) -> list[types.PayloadSpec]:
+    """
+    Determine how to stage local packages so that imports continue to work.
+
+    Args:
+        additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+            import path
+            e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
+            if there is no import path, the last part of path will be considered as import path
+            e.g. the import path of "proj/src/helper" is "helper"
+
+    Returns:
+        A list of payloadSpec for additional payloads.
+
+    Raises:
+        FileNotFoundError: If any specified package path does not exist.
+        ValueError: If the format of local_packages is invalid.
+
+    """
+    if not additional_payloads:
+        return []
+
+    logger.warning(
+        "When providing a stage path as an additional payload, "
+        "please ensure it points to a directory. "
+        "Files are not currently supported."
+    )
+
+    additional_payloads_paths = []
+    for pkg in additional_payloads:
+        if isinstance(pkg, str):
+            source_path = resolve_path(pkg).absolute()
+            module_path = source_path.name
+        elif isinstance(pkg, tuple):
+            try:
+                source_path_str, module_path = pkg
+            except ValueError:
+                raise ValueError(
+                    f"Invalid format in `additional_payloads`. "
+                    f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                )
+            source_path = resolve_path(source_path_str).absolute()
+        else:
+            raise ValueError("the format of additional payload is not correct")
+
+        if not source_path.exists():
+            raise FileNotFoundError(f"{source_path} does not exist")
+
+        if isinstance(source_path, Path):
+            if source_path.is_file():
+                raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+        module_parts = module_path.split(".")
+        for part in module_parts:
+            if not part.isidentifier() or keyword.iskeyword(part):
+                raise ValueError(
+                    f"Invalid module import path '{module_path}'. "
+                    f"'{part}' is not a valid Python identifier or is a keyword."
+                )
+
+        dest_path = PurePath(*module_parts)
+        additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+    return additional_payloads_paths
+
+
 class JobPayload:
     def __init__(
         self,
```
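Note: the staging destination for an additional payload is derived from its import path: every dot-separated component is validated as a non-keyword identifier, then the components are joined into a relative directory. A runnable sketch of just that step (the helper name `dest_for` is mine; it mirrors the loop above):

```python
import keyword
from pathlib import PurePath


def dest_for(module_path: str) -> PurePath:
    # Mirrors the validation + PurePath(*parts) step in resolve_additional_payloads.
    parts = module_path.split(".")
    for part in parts:
        if not part.isidentifier() or keyword.iskeyword(part):
            raise ValueError(f"Invalid module import path '{module_path}'.")
    return PurePath(*parts)


print(dest_for("src.utils"))  # src/utils -> staged so `import src.utils` resolves
print(dest_for("helper"))     # helper
try:
    dest_for("src.class")     # 'class' is a keyword -> rejected
except ValueError as e:
    print(e)
```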
```diff
@@ -296,11 +413,13 @@ class JobPayload:
         entrypoint: Optional[Union[str, Path]] = None,
         *,
         pip_requirements: Optional[list[str]] = None,
+        additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
     ) -> None:
         # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-        self.source =
-        self.entrypoint =
+        self.source = resolve_path(source) if isinstance(source, str) else source
+        self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
+        self.additional_payloads = additional_payloads
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
         # Prepare local variables
```
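Note: a hedged construction sketch for the new keyword-only argument (paths are hypothetical, and `JobPayload` is internal API under `jobs/_utils`; whether the public job-submission entry points forward `additional_payloads` is not shown in this diff):

```python
# Illustrative only: assumes snowflake-ml-python 1.9.2 and local directories
# matching these hypothetical paths.
from snowflake.ml.jobs._utils.payload_utils import JobPayload

payload = JobPayload(
    "proj/src",                           # source directory
    entrypoint="proj/src/train.py",
    pip_requirements=["xgboost"],
    additional_payloads=[
        ("proj/src/utils", "src.utils"),  # staged so `import src.utils` resolves
        "proj/src/helper",                # import path defaults to "helper"
    ],
)
# payload.upload(session, "@my_stage/job") would stage code under app/ and system/.
```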
```diff
@@ -308,6 +427,7 @@ class JobPayload:
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
         pip_requirements = self.pip_requirements or []
+        additional_payload_specs = resolve_additional_payloads(self.additional_payloads)
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
@@ -323,12 +443,13 @@ class JobPayload:
             params=[stage_name],
         )
 
-        # Upload payload to stage
-
+        # Upload payload to stage - organize into app/ subdirectory
+        app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+        if not isinstance(source, types.PayloadPath):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
-                stage_location=
+                stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
@@ -340,68 +461,48 @@ class JobPayload:
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
-
-            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
 
         elif isinstance(source, Path):
-
-
-                # can't handle directories. Reduce the number of PUT operations by using
-                # wildcard patterns to batch upload files with the same extension.
-                for path in {
-                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                    for p in source.resolve().rglob("*")
-                    if p.is_file()
-                }:
-                    session.file.put(
-                        str(path),
-                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                        overwrite=True,
-                        auto_compress=False,
-                    )
-            else:
-                session.file.put(
-                    str(source.resolve()),
-                    stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+            if source.is_file():
                 source = source.parent
 
-
+        upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+        # Upload requirements to app/ directory
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
         if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=
+                stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
 
-        # Upload startup script
+        # Upload startup script to system/ directory within payload
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-            stage_location=
+            stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
             auto_compress=False,
             overwrite=False,  # FIXME
         )
 
-        # Upload system scripts
         scripts_dir = Path(__file__).parent.joinpath("scripts")
         for script_file in scripts_dir.glob("*"):
             if script_file.is_file():
                 session.file.put(
                     script_file.as_posix(),
-
+                    system_stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-
         python_entrypoint: list[Union[str, PurePath]] = [
-            PurePath("mljob_launcher.py"),
-            entrypoint.file_path.relative_to(source),
+            PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+            PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
         ]
         if entrypoint.main_func:
             python_entrypoint += ["--script_main_func", entrypoint.main_func]
```
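Note: taken together, the upload logic now separates user code from system scripts on the stage. A sketch of the resulting layout (subdirectory names are assumptions inferred from `APP_STAGE_SUBPATH`/`SYSTEM_STAGE_SUBPATH` and the `/mnt/job_stage/...` defaults in mljob_launcher.py below; the actual values live in constants.py, whose hunk bodies are not included here):

```text
@<job_stage>/
├── app/                          # user payload (APP_STAGE_SUBPATH, value assumed)
│   ├── <entrypoint + source files>
│   ├── <additional payloads, e.g. src/utils/>
│   └── requirements.txt          # only when pip_requirements is set
└── system/                       # system scripts (SYSTEM_STAGE_SUBPATH, value assumed)
    ├── <startup script>          # _STARTUP_SCRIPT_PATH
    ├── mljob_launcher.py
    ├── get_instance_ip.py
    ├── worker_shutdown_listener.py
    └── signal_workers.py
```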
```diff
@@ -410,7 +511,7 @@ class JobPayload:
             stage_path=stage_path,
             entrypoint=[
                 "bash",
-                _STARTUP_SCRIPT_PATH,
+                f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                 *python_entrypoint,
             ],
         )
```
```diff
--- a/snowflake/ml/jobs/_utils/scripts/constants.py
+++ b/snowflake/ml/jobs/_utils/scripts/constants.py
@@ -1,26 +1,4 @@
-from snowflake.ml.jobs._utils import constants as mljob_constants
-
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
-# The followings are Inherited from snowflake.ml.jobs._utils.constants
-# We need to copy them here since snowml package on the server side does
-# not have the latest version of the code
-
-# Log start and end messages
-LOG_START_MSG = getattr(
-    mljob_constants,
-    "LOG_START_MSG",
-    "--------------------------------\nML job started\n--------------------------------",
-)
-LOG_END_MSG = getattr(
-    mljob_constants,
-    "LOG_END_MSG",
-    "--------------------------------\nML job finished\n--------------------------------",
-)
-
-# min_instances environment variable name
-MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
```
```diff
--- a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
+++ b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -3,6 +3,7 @@ import copy
 import importlib.util
 import json
 import logging
+import math
 import os
 import runpy
 import sys
@@ -13,7 +14,6 @@ from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.snowpark import Session
@@ -27,13 +27,35 @@ except ImportError:
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
+
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+# Log start and end messages
+LOG_START_MSG = getattr(
+    constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
 
-# Constants for the
-
-TIMEOUT = 720  # seconds
+# Constants for the wait_for_instances function
+MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
+TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
+CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds
 
 
 try:
```
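Note: the copied block uses `getattr` fallbacks so the launcher still works when the server-side snowml predates a constant. A minimal sketch of the pattern with a stand-in module object:

```python
import types

# Stand-in for snowflake.ml.jobs._utils.constants on a server running an older
# snowml: LOG_START_MSG exists, TARGET_INSTANCES_ENV_VAR does not.
constants = types.SimpleNamespace(LOG_START_MSG="ML job started")

LOG_START_MSG = getattr(constants, "LOG_START_MSG", "fallback start msg")  # resolved from module
TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")  # falls back

print(LOG_START_MSG)             # ML job started
print(TARGET_INSTANCES_ENV_VAR)  # SNOWFLAKE_JOBS_COUNT
```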
```diff
@@ -76,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
-def
+def wait_for_instances(
+    min_instances: int,
+    target_instances: int,
+    *,
+    min_wait_time: float = -1,  # seconds
+    timeout: float = 720,  # seconds
+    check_interval: float = 10,  # seconds
+) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.
 
     Args:
         min_instances: Minimum number of instances required
+        target_instances: Target number of instances to wait for
+        min_wait_time: Minimum time to wait for target_instances to be available.
+            If less than 0, automatically set based on target_instances.
+        timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+        check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+    Examples:
+        Scenario 1 - Ideal case (target met quickly):
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+            If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+        Scenario 2 - Min instances met, target not reached:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+            If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+        Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+            If 2 instances are available after 5 seconds, function continues waiting for target_instances
+            until either 4 instances are found or 30 seconds have elapsed.
+
+        Scenario 4 - Timeout scenario:
+            wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+            If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+        Scenario 5 - Single instance job (early return):
+            wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+            The function returns without waiting because target_instances <= 1.
 
     Raises:
+        ValueError: If arguments are invalid
         TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
     """
-    if min_instances
-
+    if min_instances > target_instances:
+        raise ValueError(
+            f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+        )
+    if timeout < 0:
+        raise ValueError("Timeout must be greater than 0")
+    if check_interval < 0:
+        raise ValueError("Check interval must be greater than 0")
+
+    if target_instances <= 1:
+        logger.debug("Target instances is 1 or less, no need to wait for additional instances")
         return
 
+    if min_wait_time < 0:
+        # Automatically set min_wait_time based on the number of target instances
+        # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+        #   target_instances = 1   => min_wait_time = 0
+        #   target_instances = 2   => min_wait_time = 3
+        #   target_instances = 4   => min_wait_time = 6
+        #   target_instances = 8   => min_wait_time = 9
+        #   target_instances = 32  => min_wait_time = 15
+        #   target_instances = 50  => min_wait_time = 16.9
+        #   target_instances = 100 => min_wait_time = 19.9
+        min_wait_time = min(3 * math.log2(target_instances), timeout / 10)  # Clamp to timeout / 10
+
     # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
     from common_utils import common_util as mlrs_util
 
     start_time = time.time()
-
-
-
+    current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
+    logger.debug(
+        "Waiting for instances to be ready "
+        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, timeout, check_interval
+        )
+    )
 
-    while time.time() - start_time < timeout:
+    while (elapsed := time.time() - start_time) < timeout:
         total_nodes = mlrs_util.get_num_ray_nodes()
-
-
-
+        if total_nodes >= target_instances:
+            # Best case scenario: target_instances are already available
+            logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+        elif total_nodes >= min_instances and elapsed >= min_wait_time:
+            # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return
 
         logger.debug(
-            f"Waiting for instances: {total_nodes}
-            f"
+            f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+            f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
         )
-        time.sleep(
+        time.sleep(current_interval)
+        current_interval = min(current_interval * 2, check_interval)  # Exponential backoff
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
-        f"{mlrs_util.get_num_ray_nodes()} available"
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
     )
```
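Note: the auto-derived `min_wait_time` and the backoff schedule are easy to sanity-check; a standalone sketch reproducing the numbers from the comments above:

```python
import math

# Auto min_wait_time values: 3 * log2(target_instances), clamped to timeout / 10.
timeout = 720.0
for target in (1, 2, 4, 8, 32, 50, 100):
    print(target, round(min(3 * math.log2(target), timeout / 10), 1))
# 1 0.0, 2 3.0, 4 6.0, 8 9.0, 32 15.0, 50 16.9, 100 19.9

# Poll intervals from the exponential backoff: start at 1s, double each check,
# cap at check_interval (10s by default).
interval, cap, schedule = 1.0, 10.0, []
for _ in range(6):
    schedule.append(interval)
    interval = min(interval * 2, cap)
print(schedule)  # [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]
```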
```diff
@@ -137,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None
     original_argv = sys.argv
     sys.argv = [script_path, *script_args]
 
+    # Ensure payload directory is in sys.path for module imports
+    # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+    # but user scripts are in the payload directory and may import from each other
+    payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+    if payload_dir and payload_dir not in sys.path:
+        sys.path.insert(0, payload_dir)
+
     # Create a Snowpark session before running the script
     # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
     session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841
```
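Note: a sketch of the effect, with assumed mount paths (the hunk shows the env var names, not the value of `constants.APP_MOUNT_PATH`):

```python
import os
import sys

# Assumed layout: launcher at /mnt/job_stage/system/mljob_launcher.py, user code
# at /mnt/job_stage/app/. With the payload dir on sys.path, a staged script
# app/train.py can `import helper` for a sibling app/helper.py even though the
# launcher's own directory is elsewhere. "MLRS_PAYLOAD_DIR" is the fallback env
# var name from this diff; the /mnt/job_stage/app default here is an assumption.
payload_dir = os.environ.get("MLRS_PAYLOAD_DIR", "/mnt/job_stage/app")
if payload_dir and payload_dir not in sys.path:
    sys.path.insert(0, payload_dir)
```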
```diff
@@ -183,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
+    # Ensure the output directory exists before trying to write result files.
+    output_dir = os.path.dirname(JOB_RESULT_PATH)
+    os.makedirs(output_dir, exist_ok=True)
+
     try:
         # Wait for minimum required instances if specified
         min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-
-
+        target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+        if target_instances_str and int(target_instances_str) > 1:
+            wait_for_instances(
+                int(min_instances_str),
+                int(target_instances_str),
+                min_wait_time=MIN_WAIT_TIME,
+                timeout=TIMEOUT,
+                check_interval=CHECK_INTERVAL,
+            )
 
         # Log start marker for user script execution
         print(LOG_START_MSG)  # noqa: T201
```
```diff
--- a/snowflake/ml/jobs/_utils/spec_utils.py
+++ b/snowflake/ml/jobs/_utils/spec_utils.py
@@ -181,7 +181,7 @@ def generate_service_spec(
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
     env_vars = {
-        constants.PAYLOAD_DIR_ENV_VAR:
+        constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
     }
     endpoints: list[dict[str, Any]] = []
```