snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +118 -4
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback/lightgbm.py +55 -0
  9. snowflake/ml/experiment/callback/xgboost.py +63 -0
  10. snowflake/ml/experiment/utils.py +14 -0
  11. snowflake/ml/jobs/_utils/constants.py +15 -4
  12. snowflake/ml/jobs/_utils/payload_utils.py +159 -52
  13. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  14. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
  15. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  16. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  17. snowflake/ml/jobs/_utils/types.py +64 -4
  18. snowflake/ml/jobs/job.py +22 -6
  19. snowflake/ml/jobs/manager.py +5 -3
  20. snowflake/ml/model/_client/model/model_version_impl.py +56 -48
  21. snowflake/ml/model/_client/ops/service_ops.py +194 -14
  22. snowflake/ml/model/_client/sql/service.py +1 -38
  23. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  24. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  25. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  26. snowflake/ml/model/_signatures/utils.py +4 -0
  27. snowflake/ml/model/event_handler.py +87 -18
  28. snowflake/ml/model/model_signature.py +2 -0
  29. snowflake/ml/model/models/huggingface_pipeline.py +71 -49
  30. snowflake/ml/model/type_hints.py +26 -1
  31. snowflake/ml/registry/_manager/model_manager.py +30 -35
  32. snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
  33. snowflake/ml/registry/registry.py +0 -19
  34. snowflake/ml/version.py +1 -1
  35. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
  36. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
  37. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py (+159 -52)

@@ -2,6 +2,8 @@ import functools
  import inspect
  import io
  import itertools
+ import keyword
+ import logging
  import pickle
  import sys
  import textwrap
@@ -22,8 +24,11 @@ from snowflake.ml.jobs._utils import (
  from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation

+ logger = logging.getLogger(__name__)
+
  cp.register_pickle_by_value(function_payload_utils)

+
  _SUPPORTED_ARG_TYPES = {str, int, float}
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _ENTRYPOINT_FUNC_NAME = "func"
@@ -32,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  f"""
  #!/bin/bash

+ ##### Get system scripts directory #####
+ SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
  ##### Perform common set up steps #####
  set -e # exit if a command fails

@@ -55,6 +63,13 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  ##### Set up Python environment #####
  export PYTHONPATH=/opt/env/site-packages/
+ MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
+
+ if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
+     echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
+     pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+ fi
+
  MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
  if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
      # TODO: Prevent collisions with MLRS packages using virtualenvs
@@ -75,12 +90,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Check if the local get_instance_ip.py script exists
  HELPER_EXISTS=$(
-     [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+     [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
  )

+
  # Configure IP address and logging directory
  if [ "$HELPER_EXISTS" = "true" ]; then
-     eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+     eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+         "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
  else
      eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
  fi
@@ -103,7 +120,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Determine if it should be a worker or a head node for batch jobs
  if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-     head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+     head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
      if [ $? -eq 0 ]; then
          # Parse the output using read
          read head_index head_ip head_status<<< "$head_info"
@@ -185,7 +202,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Start the worker shutdown listener in the background
  echo "Starting worker shutdown listener..."
- python worker_shutdown_listener.py
+ python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
  WORKER_EXIT_CODE=$?

  echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
@@ -218,19 +235,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

      # After the user's job completes, signal workers to shut down
      echo "User job completed. Signaling workers to shut down..."
-     python signal_workers.py --wait-time 15
+     python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
      echo "Head node job completed. Exiting."
  fi
  """
  ).strip()


+ def resolve_path(path: str) -> types.PayloadPath:
+     try:
+         stage_path = stage_utils.StagePath(path)
+     except ValueError:
+         return Path(path)
+     return stage_path
+
+
+ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+     for source_path, remote_relative_path in payload_specs:
+         payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+         if isinstance(source_path, stage_utils.StagePath):
+             # Only copy files into one stage directory from another stage directory, not from a stage file,
+             # due to incomplete StagePath functionality.
+             session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+         elif isinstance(source_path, Path):
+             if source_path.is_dir():
+                 # Manually traverse the directory and upload each file, since Snowflake PUT
+                 # can't handle directories. Reduce the number of PUT operations by using
+                 # wildcard patterns to batch upload files with the same extension.
+                 for path in {
+                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                     for p in source_path.resolve().rglob("*")
+                     if p.is_file()
+                 }:
+                     session.file.put(
+                         str(path),
+                         payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                         overwrite=True,
+                         auto_compress=False,
+                     )
+             else:
+                 session.file.put(
+                     str(source_path.resolve()),
+                     payload_stage_path.as_posix(),
+                     overwrite=True,
+                     auto_compress=False,
+                 )
+
+
  def resolve_source(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
- ) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
+     source: Union[types.PayloadPath, Callable[..., Any]]
+ ) -> Union[types.PayloadPath, Callable[..., Any]]:
      if callable(source):
          return source
-     elif isinstance(source, (Path, stage_utils.StagePath)):
+     elif isinstance(source, types.PayloadPath):
          if not source.exists():
              raise FileNotFoundError(f"{source} does not exist")
          return source.absolute()
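
The new resolve_path helper simply tries the stage-path interpretation first and falls back to a local Path. The standalone sketch below mirrors that dispatch; StagePathStub is a hypothetical stand-in for stage_utils.StagePath, and the only behavior assumed is that the real class raises ValueError for strings that are not stage locations, which is exactly what the fallback relies on.

```python
# Standalone sketch of the resolve_path() dispatch shown in the hunk above.
from pathlib import Path
from typing import Union


class StagePathStub:
    """Hypothetical stand-in for stage_utils.StagePath."""

    def __init__(self, path: str) -> None:
        if not path.startswith(("@", "snow://")):
            raise ValueError(f"not a stage path: {path}")
        self.raw = path


def resolve_path_sketch(path: str) -> Union[StagePathStub, Path]:
    # Prefer the stage interpretation; fall back to a local filesystem path.
    try:
        return StagePathStub(path)
    except ValueError:
        return Path(path)


print(type(resolve_path_sketch("@my_stage/jobs/train")).__name__)  # StagePathStub
print(type(resolve_path_sketch("./proj/src/train.py")).__name__)   # Path
```
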
@@ -239,8 +296,8 @@ def resolve_source(


  def resolve_entrypoint(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
-     entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+     source: Union[types.PayloadPath, Callable[..., Any]],
+     entrypoint: Optional[types.PayloadPath],
  ) -> types.PayloadEntrypoint:
      if callable(source):
          # Entrypoint is generated for callable payloads
@@ -289,6 +346,73 @@ def resolve_entrypoint(
      )


+ def resolve_additional_payloads(
+     additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+ ) -> list[types.PayloadSpec]:
+     """
+     Determine how to stage local packages so that imports continue to work.
+
+     Args:
+         additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+             import path,
+             e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"].
+             If no import path is given, the last part of the path is used as the import path,
+             e.g. the import path of "proj/src/helper" is "helper".
+
+     Returns:
+         A list of PayloadSpec entries for the additional payloads.
+
+     Raises:
+         FileNotFoundError: If any specified package path does not exist.
+         ValueError: If the format of additional_payloads is invalid.
+
+     """
+     if not additional_payloads:
+         return []
+
+     logger.warning(
+         "When providing a stage path as an additional payload, "
+         "please ensure it points to a directory. "
+         "Files are not currently supported."
+     )
+
+     additional_payloads_paths = []
+     for pkg in additional_payloads:
+         if isinstance(pkg, str):
+             source_path = resolve_path(pkg).absolute()
+             module_path = source_path.name
+         elif isinstance(pkg, tuple):
+             try:
+                 source_path_str, module_path = pkg
+             except ValueError:
+                 raise ValueError(
+                     f"Invalid format in `additional_payloads`. "
+                     f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                 )
+             source_path = resolve_path(source_path_str).absolute()
+         else:
+             raise ValueError("the format of additional payload is not correct")
+
+         if not source_path.exists():
+             raise FileNotFoundError(f"{source_path} does not exist")
+
+         if isinstance(source_path, Path):
+             if source_path.is_file():
+                 raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+         module_parts = module_path.split(".")
+         for part in module_parts:
+             if not part.isidentifier() or keyword.iskeyword(part):
+                 raise ValueError(
+                     f"Invalid module import path '{module_path}'. "
+                     f"'{part}' is not a valid Python identifier or is a keyword."
+                 )
+
+         dest_path = PurePath(*module_parts)
+         additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+     return additional_payloads_paths
+
+
  class JobPayload:
      def __init__(
          self,
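
To make the staging rules for additional_payloads concrete, here is a self-contained sketch of the source-to-destination mapping performed by resolve_additional_payloads, with plain (source, dest) tuples standing in for types.PayloadSpec and the existence checks omitted. The paths and package names are illustrative only.

```python
import keyword
from pathlib import Path, PurePath
from typing import Union


def map_additional_payloads(
    additional_payloads: list[Union[str, tuple[str, str]]]
) -> list[tuple[Path, PurePath]]:
    specs = []
    for pkg in additional_payloads:
        if isinstance(pkg, str):
            source, module_path = Path(pkg).absolute(), Path(pkg).name
        else:
            source_str, module_path = pkg
            source = Path(source_str).absolute()
        for part in module_path.split("."):
            if not part.isidentifier() or keyword.iskeyword(part):
                raise ValueError(f"invalid module import path: {module_path!r}")
        # The dot-separated import path becomes a nested directory under the payload root.
        specs.append((source, PurePath(*module_path.split("."))))
    return specs


# ("proj/common/utils", "common.utils") is staged under common/utils/, so
# `import common.utils` keeps working once the payload root is on sys.path.
for src, dest in map_additional_payloads([("proj/common/utils", "common.utils"), "proj/helper"]):
    print(src.name, "->", dest.as_posix())
# utils -> common/utils
# helper -> helper
```
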
@@ -296,11 +420,13 @@ class JobPayload:
          entrypoint: Optional[Union[str, Path]] = None,
          *,
          pip_requirements: Optional[list[str]] = None,
+         additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
      ) -> None:
          # For stage paths like snow://domain....., Path(path) would remove duplicate /, turning them into snow:/domain...
-         self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
-         self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+         self.source = resolve_path(source) if isinstance(source, str) else source
+         self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
          self.pip_requirements = pip_requirements
+         self.additional_payloads = additional_payloads

      def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
          # Prepare local variables
@@ -308,6 +434,7 @@ class JobPayload:
          source = resolve_source(self.source)
          entrypoint = resolve_entrypoint(source, self.entrypoint)
          pip_requirements = self.pip_requirements or []
+         additional_payload_specs = resolve_additional_payloads(self.additional_payloads)

          # Create stage if necessary
          stage_name = stage_path.parts[0].lstrip("@")
@@ -323,85 +450,65 @@
              params=[stage_name],
          )

-         # Upload payload to stage
-         if not isinstance(source, (Path, stage_utils.StagePath)):
+         # Upload payload to stage - organize into app/ subdirectory
+         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+         if not isinstance(source, types.PayloadPath):
              source_code = generate_python_code(source, source_code_display=True)
              _ = session.file.put_stream(
                  io.BytesIO(source_code.encode()),
-                 stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
+                 stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )
              source = Path(entrypoint.file_path.parent)
-             if not any(r.startswith("cloudpickle") for r in pip_requirements):
-                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")

          elif isinstance(source, stage_utils.StagePath):
              # copy payload to stage
              if source == entrypoint.file_path:
                  source = source.parent
-             source_path = source.as_posix() + "/"
-             session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

          elif isinstance(source, Path):
-             if source.is_dir():
-                 # Manually traverse the directory and upload each file, since Snowflake PUT
-                 # can't handle directories. Reduce the number of PUT operations by using
-                 # wildcard patterns to batch upload files with the same extension.
-                 for path in {
-                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                     for p in source.resolve().rglob("*")
-                     if p.is_file()
-                 }:
-                     session.file.put(
-                         str(path),
-                         stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                         overwrite=True,
-                         auto_compress=False,
-                     )
-             else:
-                 session.file.put(
-                     str(source.resolve()),
-                     stage_path.as_posix(),
-                     overwrite=True,
-                     auto_compress=False,
-                 )
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+             if source.is_file():
                  source = source.parent

-         # Upload requirements
-         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
+         upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+         if not any(r.startswith("cloudpickle") for r in pip_requirements):
+             pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+
+         # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+         system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
          if pip_requirements:
              # Upload requirements.txt to stage
              session.file.put_stream(
                  io.BytesIO("\n".join(pip_requirements).encode()),
-                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
+                 stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )

-         # Upload startup script
          # TODO: Make sure payload does not include file with same name
          session.file.put_stream(
              io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-             stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
+             stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
              auto_compress=False,
              overwrite=False,  # FIXME
          )

-         # Upload system scripts
          scripts_dir = Path(__file__).parent.joinpath("scripts")
          for script_file in scripts_dir.glob("*"):
              if script_file.is_file():
                  session.file.put(
                      script_file.as_posix(),
-                     stage_path.as_posix(),
+                     system_stage_path.as_posix(),
                      overwrite=True,
                      auto_compress=False,
                  )
-
          python_entrypoint: list[Union[str, PurePath]] = [
-             PurePath("mljob_launcher.py"),
-             entrypoint.file_path.relative_to(source),
+             PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+             PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
          ]
          if entrypoint.main_func:
              python_entrypoint += ["--script_main_func", entrypoint.main_func]
@@ -410,7 +517,7 @@
              stage_path=stage_path,
              entrypoint=[
                  "bash",
-                 _STARTUP_SCRIPT_PATH,
+                 f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                  *python_entrypoint,
              ],
          )
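
Taken together, the payload changes split the job stage into an app/ area for user code and a system/ area for the launcher and helper scripts, and thread the new additional_payloads argument through JobPayload. A hedged sketch of how the internal class is driven follows; the project layout, stage name, and package names are made up, `session` is assumed to be an existing snowflake.snowpark.Session, and callers would normally go through the public snowflake.ml.jobs APIs rather than this internal helper.

```python
# Hypothetical use of the (internal) JobPayload class to show the new argument.
from pathlib import PurePath

from snowflake.ml.jobs._utils.payload_utils import JobPayload

payload = JobPayload(
    source="./proj/src",
    entrypoint="./proj/src/train.py",
    pip_requirements=["xgboost"],
    # New in 1.10.0: extra local packages staged alongside the source so their imports keep working.
    additional_payloads=[("proj/common/utils", "common.utils")],
)

# upload() now places user code under <stage>/app/ and the launcher plus helper
# scripts under <stage>/system/ (see the app_stage_path/system_stage_path logic above).
# uploaded = payload.upload(session, PurePath("@MY_JOB_STAGE/my_job"))
```
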
snowflake/ml/jobs/_utils/scripts/constants.py (+0 -22)

@@ -1,26 +1,4 @@
- from snowflake.ml.jobs._utils import constants as mljob_constants
-
  # Constants defining the shutdown signal actor configuration.
  SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
  SHUTDOWN_ACTOR_NAMESPACE = "default"
  SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
- # The followings are Inherited from snowflake.ml.jobs._utils.constants
- # We need to copy them here since snowml package on the server side does
- # not have the latest version of the code
-
- # Log start and end messages
- LOG_START_MSG = getattr(
-     mljob_constants,
-     "LOG_START_MSG",
-     "--------------------------------\nML job started\n--------------------------------",
- )
- LOG_END_MSG = getattr(
-     mljob_constants,
-     "LOG_END_MSG",
-     "--------------------------------\nML job finished\n--------------------------------",
- )
-
- # min_instances environment variable name
- MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py (+126 -23)

@@ -3,6 +3,7 @@ import copy
  import importlib.util
  import json
  import logging
+ import math
  import os
  import runpy
  import sys
@@ -13,7 +14,6 @@ from pathlib import Path
  from typing import Any, Optional

  import cloudpickle
- from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

  from snowflake.ml.jobs._utils import constants
  from snowflake.snowpark import Session
@@ -27,13 +27,35 @@ except ImportError:
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
  logger = logging.getLogger(__name__)

+
+ # The following are inherited from snowflake.ml.jobs._utils.constants.
+ # We need to copy them here since the snowml package on the server side does
+ # not have the latest version of the code.
+ # Log start and end messages
+ LOG_START_MSG = getattr(
+     constants,
+     "LOG_START_MSG",
+     "--------------------------------\nML job started\n--------------------------------",
+ )
+ LOG_END_MSG = getattr(
+     constants,
+     "LOG_END_MSG",
+     "--------------------------------\nML job finished\n--------------------------------",
+ )
+
+ # min_instances environment variable name
+ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+ TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
  # Fallbacks in case of SnowML version mismatch
  RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
- JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+ JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+ PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")

- # Constants for the wait_for_min_instances function
- CHECK_INTERVAL = 10 # seconds
- TIMEOUT = 720 # seconds
+ # Constants for the wait_for_instances function
+ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
+ TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
+ CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds


  try:
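
One small point about these tuning constants: reading them as `float(os.getenv(...) or default)` means both an unset variable and one that is set but empty fall back to the default, since `float("")` would raise. A quick sketch:

```python
import os

os.environ["MLRS_INSTANCES_TIMEOUT"] = ""  # set but empty
print(float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720))  # 720.0 (falls back)

os.environ["MLRS_INSTANCES_TIMEOUT"] = "300"
print(float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720))  # 300.0
```
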
@@ -76,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
          return f"Unserializable object: {repr(obj)}"


- def wait_for_min_instances(min_instances: int) -> None:
+ def wait_for_instances(
+     min_instances: int,
+     target_instances: int,
+     *,
+     min_wait_time: float = -1,  # seconds
+     timeout: float = 720,  # seconds
+     check_interval: float = 10,  # seconds
+ ) -> None:
      """
      Wait until the specified minimum number of instances are available in the Ray cluster.

      Args:
          min_instances: Minimum number of instances required
+         target_instances: Target number of instances to wait for
+         min_wait_time: Minimum time to wait for target_instances to be available.
+             If less than 0, automatically set based on target_instances.
+         timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+         check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+     Examples:
+         Scenario 1 - Ideal case (target met quickly):
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+             If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+         Scenario 2 - Min instances met, target not reached:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+             If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+         Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+             If 2 instances are available after 5 seconds, function continues waiting for target_instances
+             until either 4 instances are found or 30 seconds have elapsed.
+
+         Scenario 4 - Timeout scenario:
+             wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+             If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+         Scenario 5 - Single instance job (early return):
+             wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+             The function returns without waiting because target_instances <= 1.

      Raises:
+         ValueError: If arguments are invalid
          TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
      """
-     if min_instances <= 1:
-         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+     if min_instances > target_instances:
+         raise ValueError(
+             f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+         )
+     if timeout < 0:
+         raise ValueError("Timeout must be greater than 0")
+     if check_interval < 0:
+         raise ValueError("Check interval must be greater than 0")
+
+     if target_instances <= 1:
+         logger.debug("Target instances is 1 or less, no need to wait for additional instances")
          return

+     if min_wait_time < 0:
+         # Automatically set min_wait_time based on the number of target instances
+         # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+         #   target_instances = 1   => min_wait_time = 0
+         #   target_instances = 2   => min_wait_time = 3
+         #   target_instances = 4   => min_wait_time = 6
+         #   target_instances = 8   => min_wait_time = 9
+         #   target_instances = 32  => min_wait_time = 15
+         #   target_instances = 50  => min_wait_time = 16.9
+         #   target_instances = 100 => min_wait_time = 19.9
+         min_wait_time = min(3 * math.log2(target_instances), timeout / 10)  # Clamp to timeout / 10
+
      # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
      from common_utils import common_util as mlrs_util

      start_time = time.time()
-     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
-     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
-     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+     current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
+     logger.debug(
+         "Waiting for instances to be ready "
+         "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+             min_instances, target_instances, timeout, check_interval
+         )
+     )

-     while time.time() - start_time < timeout:
+     while (elapsed := time.time() - start_time) < timeout:
          total_nodes = mlrs_util.get_num_ray_nodes()
-
-         if total_nodes >= min_instances:
-             elapsed = time.time() - start_time
+         if total_nodes >= target_instances:
+             # Best case scenario: target_instances are already available
+             logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+             return
+         elif total_nodes >= min_instances and elapsed >= min_wait_time:
+             # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
              logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
              return

-         logger.debug(
-             f"Waiting for instances: {total_nodes}/{min_instances} available "
-             f"(elapsed: {time.time() - start_time:.1f}s)"
+         logger.info(
+             f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+             f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
          )
-         time.sleep(check_interval)
+         time.sleep(current_interval)
+         current_interval = min(current_interval * 2, check_interval)  # Exponential backoff

      raise TimeoutError(
-         f"Timed out after {timeout}s waiting for {min_instances} instances, only "
-         f"{mlrs_util.get_num_ray_nodes()} available"
+         f"Timed out after {elapsed}s waiting for {min_instances} instances, only " f"{total_nodes} available"
      )


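
The two numeric policies in wait_for_instances can be replayed on their own: the automatic min_wait_time is 3 * log2(target_instances) clamped to timeout / 10, and the poll interval starts near 1 s and doubles up to check_interval. The numbers below simply illustrate those formulas.

```python
import math


def auto_min_wait_time(target_instances: int, timeout: float = 720.0) -> float:
    # Same formula as wait_for_instances(): 3 * log2(target), clamped to timeout / 10.
    return min(3 * math.log2(target_instances), timeout / 10)


for n in (2, 4, 8, 32, 100):
    print(n, round(auto_min_wait_time(n), 1))  # 3.0, 6.0, 9.0, 15.0, 19.9

# Exponential backoff of the poll interval, capped at check_interval.
check_interval = 10.0
interval = max(min(1.0, check_interval), 0.1)
intervals = []
for _ in range(6):
    intervals.append(interval)
    interval = min(interval * 2, check_interval)
print(intervals)  # [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]
```
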
@@ -137,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
      original_argv = sys.argv
      sys.argv = [script_path, *script_args]

+     # Ensure payload directory is in sys.path for module imports
+     # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+     # but user scripts are in the payload directory and may import from each other
+     payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+     if payload_dir and payload_dir not in sys.path:
+         sys.path.insert(0, payload_dir)
+
      # Create a Snowpark session before running the script
      # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
      session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841
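
A small illustration of why run_script now prepends the payload directory to sys.path: user modules sit side by side under the payload mount and import each other by top-level name. The temporary directory and helper module below are made up for the demonstration.

```python
import os
import sys
import tempfile

payload_dir = tempfile.mkdtemp()  # stands in for the directory named by MLRS_PAYLOAD_DIR

# A sibling module that the user's entrypoint script wants to import.
with open(os.path.join(payload_dir, "helper.py"), "w") as f:
    f.write("GREETING = 'hello from helper'\n")

if payload_dir not in sys.path:
    sys.path.insert(0, payload_dir)  # mirrors the insertion done by run_script()

import helper  # resolvable only because payload_dir was prepended

print(helper.GREETING)
```
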
@@ -183,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
      Raises:
          Exception: Re-raises any exception caught during script execution.
      """
+     # Ensure the output directory exists before trying to write result files.
+     output_dir = os.path.dirname(JOB_RESULT_PATH)
+     os.makedirs(output_dir, exist_ok=True)
+
      try:
          # Wait for minimum required instances if specified
          min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-         if min_instances_str and int(min_instances_str) > 1:
-             wait_for_min_instances(int(min_instances_str))
+         target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+         if target_instances_str and int(target_instances_str) > 1:
+             wait_for_instances(
+                 int(min_instances_str),
+                 int(target_instances_str),
+                 min_wait_time=MIN_WAIT_TIME,
+                 timeout=TIMEOUT,
+                 check_interval=CHECK_INTERVAL,
+             )

          # Log start marker for user script execution
          print(LOG_START_MSG)  # noqa: T201
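
For completeness, how the two instance counts feed into the call above: the launcher now gates on the target count (SNOWFLAKE_JOBS_COUNT by default) rather than the minimum, and the target falls back to the minimum when the jobs-count variable is absent. The values below are illustrative.

```python
import os

MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"      # fallback names used by the launcher
TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"

os.environ[MIN_INSTANCES_ENV_VAR] = "2"
os.environ.pop(TARGET_INSTANCES_ENV_VAR, None)    # jobs count not set

min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
print(int(min_instances_str), int(target_instances_str))  # 2 2 -> wait_for_instances(2, 2)
```
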
snowflake/ml/jobs/_utils/spec_utils.py (+1 -1)

@@ -181,7 +181,7 @@ def generate_service_spec(
      # TODO: Add hooks for endpoints for integration with TensorBoard etc

      env_vars = {
-         constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+         constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
          constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
      }
      endpoints: list[dict[str, Any]] = []