snowflake-ml-python 1.9.1__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +101 -1
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback.py +121 -0
  9. snowflake/ml/jobs/_utils/constants.py +15 -4
  10. snowflake/ml/jobs/_utils/payload_utils.py +150 -49
  11. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +125 -22
  13. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  14. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  15. snowflake/ml/jobs/_utils/types.py +64 -4
  16. snowflake/ml/jobs/job.py +22 -6
  17. snowflake/ml/jobs/manager.py +5 -3
  18. snowflake/ml/model/_client/ops/service_ops.py +17 -2
  19. snowflake/ml/model/_client/sql/service.py +1 -38
  20. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  21. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  22. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  23. snowflake/ml/model/_signatures/utils.py +4 -0
  24. snowflake/ml/model/model_signature.py +2 -0
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +42 -4
  27. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +30 -28
  28. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py
@@ -2,6 +2,8 @@ import functools
  import inspect
  import io
  import itertools
+ import keyword
+ import logging
  import pickle
  import sys
  import textwrap
@@ -22,8 +24,11 @@ from snowflake.ml.jobs._utils import (
  from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation

+ logger = logging.getLogger(__name__)
+
  cp.register_pickle_by_value(function_payload_utils)

+
  _SUPPORTED_ARG_TYPES = {str, int, float}
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _ENTRYPOINT_FUNC_NAME = "func"
@@ -32,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  f"""
  #!/bin/bash

+ ##### Get system scripts directory #####
+ SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
  ##### Perform common set up steps #####
  set -e # exit if a command fails

@@ -75,12 +83,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Check if the local get_instance_ip.py script exists
  HELPER_EXISTS=$(
- [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+ [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
  )

+
  # Configure IP address and logging directory
  if [ "$HELPER_EXISTS" = "true" ]; then
- eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+ eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+ "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
  else
  eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
  fi
@@ -103,7 +113,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Determine if it should be a worker or a head node for batch jobs
  if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
- head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+ head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
  if [ $? -eq 0 ]; then
  # Parse the output using read
  read head_index head_ip head_status<<< "$head_info"
@@ -185,7 +195,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Start the worker shutdown listener in the background
  echo "Starting worker shutdown listener..."
- python worker_shutdown_listener.py
+ python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
  WORKER_EXIT_CODE=$?

  echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
@@ -218,19 +228,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # After the user's job completes, signal workers to shut down
  echo "User job completed. Signaling workers to shut down..."
- python signal_workers.py --wait-time 15
+ python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
  echo "Head node job completed. Exiting."
  fi
  """
  ).strip()


+ def resolve_path(path: str) -> types.PayloadPath:
+     try:
+         stage_path = stage_utils.StagePath(path)
+     except ValueError:
+         return Path(path)
+     return stage_path
+
+
+ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+     for source_path, remote_relative_path in payload_specs:
+         payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+         if isinstance(source_path, stage_utils.StagePath):
+             # only copy files into one stage directory from another stage directory, not from stage file
+             # due to incomplete of StagePath functionality
+             session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+         elif isinstance(source_path, Path):
+             if source_path.is_dir():
+                 # Manually traverse the directory and upload each file, since Snowflake PUT
+                 # can't handle directories. Reduce the number of PUT operations by using
+                 # wildcard patterns to batch upload files with the same extension.
+                 for path in {
+                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                     for p in source_path.resolve().rglob("*")
+                     if p.is_file()
+                 }:
+                     session.file.put(
+                         str(path),
+                         payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                         overwrite=True,
+                         auto_compress=False,
+                     )
+             else:
+                 session.file.put(
+                     str(source_path.resolve()),
+                     payload_stage_path.as_posix(),
+                     overwrite=True,
+                     auto_compress=False,
+                 )
+
+
  def resolve_source(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
- ) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
+     source: Union[types.PayloadPath, Callable[..., Any]]
+ ) -> Union[types.PayloadPath, Callable[..., Any]]:
      if callable(source):
          return source
-     elif isinstance(source, (Path, stage_utils.StagePath)):
+     elif isinstance(source, types.PayloadPath):
          if not source.exists():
              raise FileNotFoundError(f"{source} does not exist")
          return source.absolute()
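For a sense of how the wildcard batching in upload_payloads above reduces PUT calls, here is a standalone sketch of the same set comprehension (standard library only; the file names are made up):

    from pathlib import PurePosixPath

    files = [
        PurePosixPath("proj/src/main.py"),
        PurePosixPath("proj/src/util.py"),
        PurePosixPath("proj/src/data.csv"),
        PurePosixPath("proj/Makefile"),  # no suffix, so it is uploaded individually
    ]

    # Same expression as in upload_payloads: one "<dir>/*<ext>" pattern per
    # (directory, extension) pair instead of one PUT per file.
    put_targets = {
        p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
        for p in files
    }
    print(sorted(str(t) for t in put_targets))
    # ['proj/Makefile', 'proj/src/*.csv', 'proj/src/*.py']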
@@ -239,8 +289,8 @@ def resolve_source(


  def resolve_entrypoint(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
-     entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+     source: Union[types.PayloadPath, Callable[..., Any]],
+     entrypoint: Optional[types.PayloadPath],
  ) -> types.PayloadEntrypoint:
      if callable(source):
          # Entrypoint is generated for callable payloads
@@ -289,6 +339,73 @@ def resolve_entrypoint(
      )


+ def resolve_additional_payloads(
+     additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+ ) -> list[types.PayloadSpec]:
+     """
+     Determine how to stage local packages so that imports continue to work.
+
+     Args:
+         additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+             import path
+             e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
+             if there is no import path, the last part of path will be considered as import path
+             e.g. the import path of "proj/src/helper" is "helper"
+
+     Returns:
+         A list of payloadSpec for additional payloads.
+
+     Raises:
+         FileNotFoundError: If any specified package path does not exist.
+         ValueError: If the format of local_packages is invalid.
+
+     """
+     if not additional_payloads:
+         return []
+
+     logger.warning(
+         "When providing a stage path as an additional payload, "
+         "please ensure it points to a directory. "
+         "Files are not currently supported."
+     )
+
+     additional_payloads_paths = []
+     for pkg in additional_payloads:
+         if isinstance(pkg, str):
+             source_path = resolve_path(pkg).absolute()
+             module_path = source_path.name
+         elif isinstance(pkg, tuple):
+             try:
+                 source_path_str, module_path = pkg
+             except ValueError:
+                 raise ValueError(
+                     f"Invalid format in `additional_payloads`. "
+                     f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                 )
+             source_path = resolve_path(source_path_str).absolute()
+         else:
+             raise ValueError("the format of additional payload is not correct")
+
+         if not source_path.exists():
+             raise FileNotFoundError(f"{source_path} does not exist")
+
+         if isinstance(source_path, Path):
+             if source_path.is_file():
+                 raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+         module_parts = module_path.split(".")
+         for part in module_parts:
+             if not part.isidentifier() or keyword.iskeyword(part):
+                 raise ValueError(
+                     f"Invalid module import path '{module_path}'. "
+                     f"'{part}' is not a valid Python identifier or is a keyword."
+                 )
+
+         dest_path = PurePath(*module_parts)
+         additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+     return additional_payloads_paths
+
+
  class JobPayload:
      def __init__(
          self,
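The mapping from a dotted import path to a staging directory in resolve_additional_payloads can be reproduced in isolation; the inputs below are illustrative only (the real code uses PurePath, the sketch uses PurePosixPath so the printed output is platform-independent):

    import keyword
    from pathlib import PurePosixPath

    def dest_for(module_path: str) -> PurePosixPath:
        # Mirrors the validation and PurePath(*parts) construction above.
        parts = module_path.split(".")
        for part in parts:
            if not part.isidentifier() or keyword.iskeyword(part):
                raise ValueError(f"invalid module import path: {module_path!r}")
        return PurePosixPath(*parts)

    print(dest_for("src.utils"))  # src/utils -> `import src.utils` keeps working on the job
    print(dest_for("helper"))     # helper    -> default when only a directory path is given
    # dest_for("src.class") raises ValueError because "class" is a Python keyword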
@@ -296,11 +413,13 @@ class JobPayload:
          entrypoint: Optional[Union[str, Path]] = None,
          *,
          pip_requirements: Optional[list[str]] = None,
+         additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
      ) -> None:
          # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-         self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
-         self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+         self.source = resolve_path(source) if isinstance(source, str) else source
+         self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
          self.pip_requirements = pip_requirements
+         self.additional_payloads = additional_payloads

      def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
          # Prepare local variables
@@ -308,6 +427,7 @@ class JobPayload:
          source = resolve_source(self.source)
          entrypoint = resolve_entrypoint(source, self.entrypoint)
          pip_requirements = self.pip_requirements or []
+         additional_payload_specs = resolve_additional_payloads(self.additional_payloads)

          # Create stage if necessary
          stage_name = stage_path.parts[0].lstrip("@")
@@ -323,12 +443,13 @@ class JobPayload:
              params=[stage_name],
          )

-         # Upload payload to stage
-         if not isinstance(source, (Path, stage_utils.StagePath)):
+         # Upload payload to stage - organize into app/ subdirectory
+         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+         if not isinstance(source, types.PayloadPath):
              source_code = generate_python_code(source, source_code_display=True)
              _ = session.file.put_stream(
                  io.BytesIO(source_code.encode()),
-                 stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
+                 stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )
@@ -340,68 +461,48 @@ class JobPayload:
              # copy payload to stage
              if source == entrypoint.file_path:
                  source = source.parent
-             source_path = source.as_posix() + "/"
-             session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

          elif isinstance(source, Path):
-             if source.is_dir():
-                 # Manually traverse the directory and upload each file, since Snowflake PUT
-                 # can't handle directories. Reduce the number of PUT operations by using
-                 # wildcard patterns to batch upload files with the same extension.
-                 for path in {
-                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                     for p in source.resolve().rglob("*")
-                     if p.is_file()
-                 }:
-                     session.file.put(
-                         str(path),
-                         stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                         overwrite=True,
-                         auto_compress=False,
-                     )
-             else:
-                 session.file.put(
-                     str(source.resolve()),
-                     stage_path.as_posix(),
-                     overwrite=True,
-                     auto_compress=False,
-                 )
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+             if source.is_file():
                  source = source.parent

-         # Upload requirements
+         upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+         # Upload requirements to app/ directory
          # TODO: Check if payload includes both a requirements.txt file and pip_requirements
          if pip_requirements:
              # Upload requirements.txt to stage
              session.file.put_stream(
                  io.BytesIO("\n".join(pip_requirements).encode()),
-                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
+                 stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )

-         # Upload startup script
+         # Upload startup script to system/ directory within payload
+         system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
          # TODO: Make sure payload does not include file with same name
          session.file.put_stream(
              io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-             stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
+             stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
              auto_compress=False,
              overwrite=False, # FIXME
          )

-         # Upload system scripts
          scripts_dir = Path(__file__).parent.joinpath("scripts")
          for script_file in scripts_dir.glob("*"):
              if script_file.is_file():
                  session.file.put(
                      script_file.as_posix(),
-                     stage_path.as_posix(),
+                     system_stage_path.as_posix(),
                      overwrite=True,
                      auto_compress=False,
                  )
-
          python_entrypoint: list[Union[str, PurePath]] = [
-             PurePath("mljob_launcher.py"),
-             entrypoint.file_path.relative_to(source),
+             PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+             PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
          ]
          if entrypoint.main_func:
              python_entrypoint += ["--script_main_func", entrypoint.main_func]
@@ -410,7 +511,7 @@ class JobPayload:
              stage_path=stage_path,
              entrypoint=[
                  "bash",
-                 _STARTUP_SCRIPT_PATH,
+                 f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                  *python_entrypoint,
              ],
          )
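The constants referenced here (APP_STAGE_SUBPATH, SYSTEM_STAGE_SUBPATH, APP_MOUNT_PATH, SYSTEM_MOUNT_PATH) live in snowflake/ml/jobs/_utils/constants.py, which also changed in this release. The sketch below shows how the launch command is assembled; the concrete path values are assumptions for illustration (the /mnt/job_stage/... layout is suggested by the mljob_launcher.py defaults further down), not values read from the package:

    from pathlib import PurePosixPath

    SYSTEM_MOUNT_PATH = "/mnt/job_stage/system"  # assumed value of constants.SYSTEM_MOUNT_PATH
    APP_MOUNT_PATH = "/mnt/job_stage/app"        # assumed value of constants.APP_MOUNT_PATH
    STARTUP_SCRIPT = "startup.sh"                # assumed value of _STARTUP_SCRIPT_PATH
    entrypoint_rel = PurePosixPath("train.py")   # entrypoint relative to the payload root (illustrative)

    python_entrypoint = [
        PurePosixPath(f"{SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
        PurePosixPath(f"{APP_MOUNT_PATH}/{entrypoint_rel.as_posix()}"),
    ]
    command = ["bash", f"{SYSTEM_MOUNT_PATH}/{STARTUP_SCRIPT}", *map(str, python_entrypoint)]
    print(" ".join(command))
    # bash /mnt/job_stage/system/startup.sh /mnt/job_stage/system/mljob_launcher.py /mnt/job_stage/app/train.py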
snowflake/ml/jobs/_utils/scripts/constants.py
@@ -1,26 +1,4 @@
- from snowflake.ml.jobs._utils import constants as mljob_constants
-
  # Constants defining the shutdown signal actor configuration.
  SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
  SHUTDOWN_ACTOR_NAMESPACE = "default"
  SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
- # The followings are Inherited from snowflake.ml.jobs._utils.constants
- # We need to copy them here since snowml package on the server side does
- # not have the latest version of the code
-
- # Log start and end messages
- LOG_START_MSG = getattr(
-     mljob_constants,
-     "LOG_START_MSG",
-     "--------------------------------\nML job started\n--------------------------------",
- )
- LOG_END_MSG = getattr(
-     mljob_constants,
-     "LOG_END_MSG",
-     "--------------------------------\nML job finished\n--------------------------------",
- )
-
- # min_instances environment variable name
- MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
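The block removed here uses the getattr-with-default pattern to tolerate version skew between the client package and the snowml package installed on the server; it now lives in mljob_launcher.py below. A standalone illustration of the pattern:

    import types

    # Stand-in for an older constants module that predates some symbols.
    constants = types.SimpleNamespace(MIN_INSTANCES_ENV_VAR="MLRS_MIN_INSTANCES")

    print(getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES"))       # present: MLRS_MIN_INSTANCES
    print(getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT"))  # missing: falls back to default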
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -3,6 +3,7 @@ import copy
  import importlib.util
  import json
  import logging
+ import math
  import os
  import runpy
  import sys
@@ -13,7 +14,6 @@ from pathlib import Path
  from typing import Any, Optional

  import cloudpickle
- from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

  from snowflake.ml.jobs._utils import constants
  from snowflake.snowpark import Session
@@ -27,13 +27,35 @@ except ImportError:
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
  logger = logging.getLogger(__name__)

+
+ # The followings are Inherited from snowflake.ml.jobs._utils.constants
+ # We need to copy them here since snowml package on the server side does
+ # not have the latest version of the code
+ # Log start and end messages
+ LOG_START_MSG = getattr(
+     constants,
+     "LOG_START_MSG",
+     "--------------------------------\nML job started\n--------------------------------",
+ )
+ LOG_END_MSG = getattr(
+     constants,
+     "LOG_END_MSG",
+     "--------------------------------\nML job finished\n--------------------------------",
+ )
+
+ # min_instances environment variable name
+ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+ TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
  # Fallbacks in case of SnowML version mismatch
  RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
- JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+ JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+ PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")

- # Constants for the wait_for_min_instances function
- CHECK_INTERVAL = 10 # seconds
- TIMEOUT = 720 # seconds
+ # Constants for the wait_for_instances function
+ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1) # seconds
+ TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720) # seconds
+ CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10) # seconds


  try:
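The `float(os.getenv(...) or default)` form above differs from `os.getenv(name, default)` in that the fallback also applies when the variable is set but empty:

    import os

    os.environ.pop("MLRS_INSTANCES_TIMEOUT", None)
    print(float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720))  # 720.0 (unset)

    os.environ["MLRS_INSTANCES_TIMEOUT"] = ""
    print(float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720))  # 720.0 (set but empty)

    os.environ["MLRS_INSTANCES_TIMEOUT"] = "90"
    print(float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720))  # 90.0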
@@ -76,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
          return f"Unserializable object: {repr(obj)}"


- def wait_for_min_instances(min_instances: int) -> None:
+ def wait_for_instances(
+     min_instances: int,
+     target_instances: int,
+     *,
+     min_wait_time: float = -1, # seconds
+     timeout: float = 720, # seconds
+     check_interval: float = 10, # seconds
+ ) -> None:
      """
      Wait until the specified minimum number of instances are available in the Ray cluster.

      Args:
          min_instances: Minimum number of instances required
+         target_instances: Target number of instances to wait for
+         min_wait_time: Minimum time to wait for target_instances to be available.
+             If less than 0, automatically set based on target_instances.
+         timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+         check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+     Examples:
+         Scenario 1 - Ideal case (target met quickly):
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+             If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+         Scenario 2 - Min instances met, target not reached:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+             If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+         Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+             If 2 instances are available after 5 seconds, function continues waiting for target_instances
+             until either 4 instances are found or 30 seconds have elapsed.
+
+         Scenario 4 - Timeout scenario:
+             wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+             If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+         Scenario 5 - Single instance job (early return):
+             wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+             The function returns without waiting because target_instances <= 1.

      Raises:
+         ValueError: If arguments are invalid
          TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
      """
-     if min_instances <= 1:
-         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+     if min_instances > target_instances:
+         raise ValueError(
+             f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+         )
+     if timeout < 0:
+         raise ValueError("Timeout must be greater than 0")
+     if check_interval < 0:
+         raise ValueError("Check interval must be greater than 0")
+
+     if target_instances <= 1:
+         logger.debug("Target instances is 1 or less, no need to wait for additional instances")
          return

+     if min_wait_time < 0:
+         # Automatically set min_wait_time based on the number of target instances
+         # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+         # target_instances = 1 => min_wait_time = 0
+         # target_instances = 2 => min_wait_time = 3
+         # target_instances = 4 => min_wait_time = 6
+         # target_instances = 8 => min_wait_time = 9
+         # target_instances = 32 => min_wait_time = 15
+         # target_instances = 50 => min_wait_time = 16.9
+         # target_instances = 100 => min_wait_time = 19.9
+         min_wait_time = min(3 * math.log2(target_instances), timeout / 10) # Clamp to timeout / 10
+
      # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
      from common_utils import common_util as mlrs_util

      start_time = time.time()
-     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
-     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
-     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+     current_interval = max(min(1, check_interval), 0.1) # Default 1s, minimum 0.1s
+     logger.debug(
+         "Waiting for instances to be ready "
+         "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+             min_instances, target_instances, timeout, check_interval
+         )
+     )

-     while time.time() - start_time < timeout:
+     while (elapsed := time.time() - start_time) < timeout:
          total_nodes = mlrs_util.get_num_ray_nodes()
-
-         if total_nodes >= min_instances:
-             elapsed = time.time() - start_time
+         if total_nodes >= target_instances:
+             # Best case scenario: target_instances are already available
+             logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+             return
+         elif total_nodes >= min_instances and elapsed >= min_wait_time:
+             # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
              logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
              return

          logger.debug(
-             f"Waiting for instances: {total_nodes}/{min_instances} available "
-             f"(elapsed: {time.time() - start_time:.1f}s)"
+             f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+             f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
          )
-         time.sleep(check_interval)
+         time.sleep(current_interval)
+         current_interval = min(current_interval * 2, check_interval) # Exponential backoff

      raise TimeoutError(
-         f"Timed out after {timeout}s waiting for {min_instances} instances, only "
-         f"{mlrs_util.get_num_ray_nodes()} available"
+         f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
      )


@@ -137,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
      original_argv = sys.argv
      sys.argv = [script_path, *script_args]

+     # Ensure payload directory is in sys.path for module imports
+     # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+     # but user scripts are in the payload directory and may import from each other
+     payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+     if payload_dir and payload_dir not in sys.path:
+         sys.path.insert(0, payload_dir)
+
      # Create a Snowpark session before running the script
      # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
      session = Session.builder.configs(SnowflakeLoginOptions()).create() # noqa: F841
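The sys.path insertion above is what lets user modules in the payload directory import each other even though the launcher now runs from the system directory. A self-contained demonstration of the mechanism (temporary paths, illustrative only):

    import importlib
    import sys
    import tempfile
    from pathlib import Path

    payload_dir = Path(tempfile.mkdtemp())  # stand-in for the mounted payload directory
    (payload_dir / "user_helper.py").write_text("GREETING = 'hello from the payload dir'\n")

    if str(payload_dir) not in sys.path:
        sys.path.insert(0, str(payload_dir))

    user_helper = importlib.import_module("user_helper")
    print(user_helper.GREETING)  # hello from the payload dir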
@@ -183,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
      Raises:
          Exception: Re-raises any exception caught during script execution.
      """
+     # Ensure the output directory exists before trying to write result files.
+     output_dir = os.path.dirname(JOB_RESULT_PATH)
+     os.makedirs(output_dir, exist_ok=True)
+
      try:
          # Wait for minimum required instances if specified
          min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-         if min_instances_str and int(min_instances_str) > 1:
-             wait_for_min_instances(int(min_instances_str))
+         target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+         if target_instances_str and int(target_instances_str) > 1:
+             wait_for_instances(
+                 int(min_instances_str),
+                 int(target_instances_str),
+                 min_wait_time=MIN_WAIT_TIME,
+                 timeout=TIMEOUT,
+                 check_interval=CHECK_INTERVAL,
+             )

          # Log start marker for user script execution
          print(LOG_START_MSG) # noqa: T201
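To make the defaults wired into wait_for_instances above concrete, the auto-derived min_wait_time and the backoff schedule can be computed with the same formulas (standalone sketch, standard library only):

    import math

    timeout = 720.0
    check_interval = 10.0

    # min_wait_time = min(3 * log2(target_instances), timeout / 10)
    for target in (2, 4, 8, 32, 100):
        print(target, round(min(3 * math.log2(target), timeout / 10), 1))
    # prints 3.0, 6.0, 9.0, 15.0 and 19.9 seconds respectively

    # Poll interval: starts at 1s and doubles until capped at check_interval.
    interval = max(min(1, check_interval), 0.1)
    schedule = []
    for _ in range(6):
        schedule.append(interval)
        interval = min(interval * 2, check_interval)
    print(schedule)  # [1, 2, 4, 8, 10.0, 10.0]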
snowflake/ml/jobs/_utils/spec_utils.py
@@ -181,7 +181,7 @@ def generate_service_spec(
      # TODO: Add hooks for endpoints for integration with TensorBoard etc

      env_vars = {
-         constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+         constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
          constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
      }
      endpoints: list[dict[str, Any]] = []