snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0

snowflake/ml/jobs/_utils/payload_utils.py

@@ -2,6 +2,8 @@ import functools
  import inspect
  import io
  import itertools
+ import keyword
+ import logging
  import pickle
  import sys
  import textwrap
@@ -12,17 +14,21 @@ import cloudpickle as cp
  from packaging import version

  from snowflake import snowpark
- from snowflake.connector import errors
  from snowflake.ml.jobs._utils import (
      constants,
      function_payload_utils,
+     query_helper,
      stage_utils,
      types,
  )
+ from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation

+ logger = logging.getLogger(__name__)
+
  cp.register_pickle_by_value(function_payload_utils)

+
  _SUPPORTED_ARG_TYPES = {str, int, float}
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _ENTRYPOINT_FUNC_NAME = "func"
@@ -31,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  f"""
  #!/bin/bash

+ ##### Get system scripts directory #####
+ SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
  ##### Perform common set up steps #####
  set -e # exit if a command fails

@@ -74,12 +83,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Check if the local get_instance_ip.py script exists
  HELPER_EXISTS=$(
-     [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+     [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
  )

+
  # Configure IP address and logging directory
  if [ "$HELPER_EXISTS" = "true" ]; then
-     eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+     eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+         "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
  else
      eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
  fi
@@ -102,7 +113,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Determine if it should be a worker or a head node for batch jobs
  if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-     head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+     head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
      if [ $? -eq 0 ]; then
          # Parse the output using read
          read head_index head_ip head_status<<< "$head_info"
@@ -184,7 +195,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # Start the worker shutdown listener in the background
  echo "Starting worker shutdown listener..."
- python worker_shutdown_listener.py
+ python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
  WORKER_EXIT_CODE=$?

  echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
@@ -217,19 +228,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(

  # After the user's job completes, signal workers to shut down
  echo "User job completed. Signaling workers to shut down..."
- python signal_workers.py --wait-time 15
+ python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
  echo "Head node job completed. Exiting."
  fi
  """
  ).strip()


+ def resolve_path(path: str) -> types.PayloadPath:
+     try:
+         stage_path = stage_utils.StagePath(path)
+     except ValueError:
+         return Path(path)
+     return stage_path
+
+
+ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+     for source_path, remote_relative_path in payload_specs:
+         payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+         if isinstance(source_path, stage_utils.StagePath):
+             # only copy files into one stage directory from another stage directory, not from stage file
+             # due to incomplete of StagePath functionality
+             session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+         elif isinstance(source_path, Path):
+             if source_path.is_dir():
+                 # Manually traverse the directory and upload each file, since Snowflake PUT
+                 # can't handle directories. Reduce the number of PUT operations by using
+                 # wildcard patterns to batch upload files with the same extension.
+                 for path in {
+                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                     for p in source_path.resolve().rglob("*")
+                     if p.is_file()
+                 }:
+                     session.file.put(
+                         str(path),
+                         payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                         overwrite=True,
+                         auto_compress=False,
+                     )
+             else:
+                 session.file.put(
+                     str(source_path.resolve()),
+                     payload_stage_path.as_posix(),
+                     overwrite=True,
+                     auto_compress=False,
+                 )
+
+
  def resolve_source(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
- ) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
+     source: Union[types.PayloadPath, Callable[..., Any]]
+ ) -> Union[types.PayloadPath, Callable[..., Any]]:
      if callable(source):
          return source
-     elif isinstance(source, (Path, stage_utils.StagePath)):
+     elif isinstance(source, types.PayloadPath):
          if not source.exists():
              raise FileNotFoundError(f"{source} does not exist")
          return source.absolute()
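
As a side note on the new helper: resolve_path is the single branch point between stage-backed and local payload paths. A minimal sketch of calling it; the example path strings are hypothetical, and StagePath raising ValueError for non-stage strings is inferred from the try/except in the hunk above:

    from snowflake.ml.jobs._utils import payload_utils

    # Hypothetical inputs; whether a string parses as a stage path is decided by
    # stage_utils.StagePath, which is assumed to raise ValueError for non-stage strings.
    remote = payload_utils.resolve_path("@ml_jobs_stage/payloads/my_job")
    local = payload_utils.resolve_path("proj/src/helper")

    print(type(remote).__name__)  # StagePath, assuming the stage string is accepted
    print(type(local).__name__)   # Path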
@@ -238,8 +289,8 @@ def resolve_source(


  def resolve_entrypoint(
-     source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
-     entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+     source: Union[types.PayloadPath, Callable[..., Any]],
+     entrypoint: Optional[types.PayloadPath],
  ) -> types.PayloadEntrypoint:
      if callable(source):
          # Entrypoint is generated for callable payloads
@@ -288,6 +339,73 @@ def resolve_entrypoint(
      )


+ def resolve_additional_payloads(
+     additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+ ) -> list[types.PayloadSpec]:
+     """
+     Determine how to stage local packages so that imports continue to work.
+
+     Args:
+         additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+             import path
+             e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
+             if there is no import path, the last part of path will be considered as import path
+             e.g. the import path of "proj/src/helper" is "helper"
+
+     Returns:
+         A list of payloadSpec for additional payloads.
+
+     Raises:
+         FileNotFoundError: If any specified package path does not exist.
+         ValueError: If the format of local_packages is invalid.
+
+     """
+     if not additional_payloads:
+         return []
+
+     logger.warning(
+         "When providing a stage path as an additional payload, "
+         "please ensure it points to a directory. "
+         "Files are not currently supported."
+     )
+
+     additional_payloads_paths = []
+     for pkg in additional_payloads:
+         if isinstance(pkg, str):
+             source_path = resolve_path(pkg).absolute()
+             module_path = source_path.name
+         elif isinstance(pkg, tuple):
+             try:
+                 source_path_str, module_path = pkg
+             except ValueError:
+                 raise ValueError(
+                     f"Invalid format in `additional_payloads`. "
+                     f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                 )
+             source_path = resolve_path(source_path_str).absolute()
+         else:
+             raise ValueError("the format of additional payload is not correct")
+
+         if not source_path.exists():
+             raise FileNotFoundError(f"{source_path} does not exist")
+
+         if isinstance(source_path, Path):
+             if source_path.is_file():
+                 raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+         module_parts = module_path.split(".")
+         for part in module_parts:
+             if not part.isidentifier() or keyword.iskeyword(part):
+                 raise ValueError(
+                     f"Invalid module import path '{module_path}'. "
+                     f"'{part}' is not a valid Python identifier or is a keyword."
+                 )
+
+         dest_path = PurePath(*module_parts)
+         additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+     return additional_payloads_paths
+
+
  class JobPayload:
      def __init__(
          self,
@@ -295,11 +413,13 @@ class JobPayload:
          entrypoint: Optional[Union[str, Path]] = None,
          *,
          pip_requirements: Optional[list[str]] = None,
+         additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
      ) -> None:
          # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-         self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
-         self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+         self.source = resolve_path(source) if isinstance(source, str) else source
+         self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
          self.pip_requirements = pip_requirements
+         self.additional_payloads = additional_payloads

      def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
          # Prepare local variables
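
Together with resolve_additional_payloads above, the new additional_payloads argument lets a job stage extra local packages next to the entrypoint. A hedged usage sketch; the project layout is made up, only the constructor signature comes from this hunk:

    from snowflake.ml.jobs._utils.payload_utils import JobPayload

    # Hypothetical project layout:
    #   proj/train.py       entrypoint
    #   proj/src/utils/     should stay importable as `src.utils`
    #   proj/src/helper/    no import path given, so it is staged as `helper`
    payload = JobPayload(
        "proj",
        entrypoint="proj/train.py",
        additional_payloads=[("proj/src/utils", "src.utils"), "proj/src/helper"],
    )
    # payload.upload(session, "@ml_jobs_stage/my_job") would place the source under
    # app/ and each additional payload under app/<import path as directories>/.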
@@ -307,27 +427,29 @@ class JobPayload:
          source = resolve_source(self.source)
          entrypoint = resolve_entrypoint(source, self.entrypoint)
          pip_requirements = self.pip_requirements or []
+         additional_payload_specs = resolve_additional_payloads(self.additional_payloads)

          # Create stage if necessary
          stage_name = stage_path.parts[0].lstrip("@")
          # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
          try:
-             session._conn.run_query("describe stage identifier(?)", params=[stage_name], _force_qmark_paramstyle=True)
-         except errors.ProgrammingError:
-             session._conn.run_query(
+             query_helper.run_query(session, "describe stage identifier(?)", params=[stage_name])
+         except sp_exceptions.SnowparkSQLException:
+             query_helper.run_query(
+                 session,
                  "create stage if not exists identifier(?)"
                  " encryption = ( type = 'SNOWFLAKE_SSE' )"
                  " comment = 'Created by snowflake.ml.jobs Python API'",
                  params=[stage_name],
-                 _force_qmark_paramstyle=True,
              )

-         # Upload payload to stage
-         if not isinstance(source, (Path, stage_utils.StagePath)):
+         # Upload payload to stage - organize into app/ subdirectory
+         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+         if not isinstance(source, types.PayloadPath):
              source_code = generate_python_code(source, source_code_display=True)
              _ = session.file.put_stream(
                  io.BytesIO(source_code.encode()),
-                 stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
+                 stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )
@@ -339,68 +461,48 @@ class JobPayload:
              # copy payload to stage
              if source == entrypoint.file_path:
                  source = source.parent
-             source_path = source.as_posix() + "/"
-             session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

          elif isinstance(source, Path):
-             if source.is_dir():
-                 # Manually traverse the directory and upload each file, since Snowflake PUT
-                 # can't handle directories. Reduce the number of PUT operations by using
-                 # wildcard patterns to batch upload files with the same extension.
-                 for path in {
-                     p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                     for p in source.resolve().rglob("*")
-                     if p.is_file()
-                 }:
-                     session.file.put(
-                         str(path),
-                         stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                         overwrite=True,
-                         auto_compress=False,
-                     )
-             else:
-                 session.file.put(
-                     str(source.resolve()),
-                     stage_path.as_posix(),
-                     overwrite=True,
-                     auto_compress=False,
-                 )
+             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+             if source.is_file():
                  source = source.parent

-         # Upload requirements
+         upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+         # Upload requirements to app/ directory
          # TODO: Check if payload includes both a requirements.txt file and pip_requirements
          if pip_requirements:
              # Upload requirements.txt to stage
              session.file.put_stream(
                  io.BytesIO("\n".join(pip_requirements).encode()),
-                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
+                 stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )

-         # Upload startup script
+         # Upload startup script to system/ directory within payload
+         system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
          # TODO: Make sure payload does not include file with same name
          session.file.put_stream(
              io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-             stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
+             stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
              auto_compress=False,
              overwrite=False, # FIXME
          )

-         # Upload system scripts
          scripts_dir = Path(__file__).parent.joinpath("scripts")
          for script_file in scripts_dir.glob("*"):
              if script_file.is_file():
                  session.file.put(
                      script_file.as_posix(),
-                     stage_path.as_posix(),
+                     system_stage_path.as_posix(),
                      overwrite=True,
                      auto_compress=False,
                  )
-
          python_entrypoint: list[Union[str, PurePath]] = [
-             PurePath("mljob_launcher.py"),
-             entrypoint.file_path.relative_to(source),
+             PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+             PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
          ]
          if entrypoint.main_func:
              python_entrypoint += ["--script_main_func", entrypoint.main_func]
@@ -409,7 +511,7 @@ class JobPayload:
              stage_path=stage_path,
              entrypoint=[
                  "bash",
-                 _STARTUP_SCRIPT_PATH,
+                 f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                  *python_entrypoint,
              ],
          )
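
The net effect of these payload_utils.py hunks is that the job stage is split into an app/ area for user code and a system/ area for the launcher and helper scripts, addressed through mount-path constants. A rough sketch of the resulting layout; the constant values below are guesses (only the /mnt/job_stage prefix is corroborated by the mljob_launcher.py hunks further down), and the startup-script and entrypoint file names are hypothetical:

    # Assumed constant values (not shown in this diff):
    APP_STAGE_SUBPATH, SYSTEM_STAGE_SUBPATH = "app", "system"
    APP_MOUNT_PATH, SYSTEM_MOUNT_PATH = "/mnt/job_stage/app", "/mnt/job_stage/system"

    # Assumed stage layout after JobPayload.upload():
    #   <stage>/<job>/app/     user source, requirements.txt, additional payloads
    #   <stage>/<job>/system/  startup script, mljob_launcher.py, helper scripts
    #
    # Shape of the service entrypoint assembled above:
    entrypoint = [
        "bash",
        f"{SYSTEM_MOUNT_PATH}/startup.sh",         # _STARTUP_SCRIPT_PATH, value not in this diff
        f"{SYSTEM_MOUNT_PATH}/mljob_launcher.py",
        f"{APP_MOUNT_PATH}/train.py",              # user entrypoint relative to the source root
    ]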

snowflake/ml/jobs/_utils/query_helper.py

@@ -1,9 +1,20 @@
+ from typing import Any, Optional, Sequence
+
  from snowflake import snowpark
+ from snowflake.snowpark import Row
+ from snowflake.snowpark._internal import utils
+ from snowflake.snowpark._internal.analyzer import snowflake_plan


- def get_attribute_map(session: snowpark.Session, requested_attributes: dict[str, int]) -> dict[str, int]:
+ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
      metadata = session._conn._cursor.description
-     for index in range(len(metadata)):
-         if metadata[index].name in requested_attributes.keys():
-             requested_attributes[metadata[index].name] = index
-     return requested_attributes
+     result_set = result["data"]
+     return utils.result_set_to_rows(result_set, metadata)
+
+
+ @snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
+ def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
+     result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True)
+     if not isinstance(result, dict) or "data" not in result:
+         raise ValueError(f"Unprocessable result: {result}")
+     return result_set_to_rows(session, result)
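
The new run_query helper returns Snowpark Row objects and routes driver errors through the SnowflakePlan exception wrapper, which appears to be what allows payload_utils.py to catch SnowparkSQLException instead of the connector's ProgrammingError. A small usage sketch mirroring the stage-existence check shown earlier; the stage name is hypothetical:

    from snowflake import snowpark
    from snowflake.snowpark import exceptions as sp_exceptions

    from snowflake.ml.jobs._utils import query_helper


    def ensure_stage(session: snowpark.Session, stage_name: str) -> None:
        # Mirrors the stage-existence check in the payload_utils.py hunk above.
        try:
            query_helper.run_query(session, "describe stage identifier(?)", params=[stage_name])
        except sp_exceptions.SnowparkSQLException:
            query_helper.run_query(
                session,
                "create stage if not exists identifier(?)"
                " encryption = ( type = 'SNOWFLAKE_SSE' )",
                params=[stage_name],
            )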

snowflake/ml/jobs/_utils/scripts/constants.py

@@ -1,26 +1,4 @@
- from snowflake.ml.jobs._utils import constants as mljob_constants
-
  # Constants defining the shutdown signal actor configuration.
  SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
  SHUTDOWN_ACTOR_NAMESPACE = "default"
  SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
- # The followings are Inherited from snowflake.ml.jobs._utils.constants
- # We need to copy them here since snowml package on the server side does
- # not have the latest version of the code
-
- # Log start and end messages
- LOG_START_MSG = getattr(
-     mljob_constants,
-     "LOG_START_MSG",
-     "--------------------------------\nML job started\n--------------------------------",
- )
- LOG_END_MSG = getattr(
-     mljob_constants,
-     "LOG_END_MSG",
-     "--------------------------------\nML job finished\n--------------------------------",
- )
-
- # min_instances environment variable name
- MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

@@ -3,6 +3,7 @@ import copy
  import importlib.util
  import json
  import logging
+ import math
  import os
  import runpy
  import sys
@@ -13,23 +14,48 @@ from pathlib import Path
  from typing import Any, Optional

  import cloudpickle
- from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

  from snowflake.ml.jobs._utils import constants
- from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
  from snowflake.snowpark import Session

+ try:
+     from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+ except ImportError:
+     from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
+
  # Configure logging
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
  logger = logging.getLogger(__name__)

+
+ # The followings are Inherited from snowflake.ml.jobs._utils.constants
+ # We need to copy them here since snowml package on the server side does
+ # not have the latest version of the code
+ # Log start and end messages
+ LOG_START_MSG = getattr(
+     constants,
+     "LOG_START_MSG",
+     "--------------------------------\nML job started\n--------------------------------",
+ )
+ LOG_END_MSG = getattr(
+     constants,
+     "LOG_END_MSG",
+     "--------------------------------\nML job finished\n--------------------------------",
+ )
+
+ # min_instances environment variable name
+ MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+ TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
  # Fallbacks in case of SnowML version mismatch
  RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
- JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+ JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+ PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")

- # Constants for the wait_for_min_instances function
- CHECK_INTERVAL = 10 # seconds
- TIMEOUT = 720 # seconds
+ # Constants for the wait_for_instances function
+ MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1) # seconds
+ TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720) # seconds
+ CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10) # seconds


  try:
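
The timing knobs for the instance wait are now read from environment variables with in-code defaults. A hedged sketch of overriding them; the variable names come from the getenv calls above, the values are arbitrary, and they must be set before the launcher module is imported since it reads them at import time:

    import os

    os.environ["MLRS_INSTANCES_MIN_WAIT"] = "30"        # wait at least 30s for the full cluster
    os.environ["MLRS_INSTANCES_TIMEOUT"] = "900"        # give up after 15 minutes
    os.environ["MLRS_INSTANCES_CHECK_INTERVAL"] = "20"  # cap the backoff interval at 20s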
@@ -72,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
          return f"Unserializable object: {repr(obj)}"


- def wait_for_min_instances(min_instances: int) -> None:
+ def wait_for_instances(
+     min_instances: int,
+     target_instances: int,
+     *,
+     min_wait_time: float = -1, # seconds
+     timeout: float = 720, # seconds
+     check_interval: float = 10, # seconds
+ ) -> None:
      """
      Wait until the specified minimum number of instances are available in the Ray cluster.

      Args:
          min_instances: Minimum number of instances required
+         target_instances: Target number of instances to wait for
+         min_wait_time: Minimum time to wait for target_instances to be available.
+             If less than 0, automatically set based on target_instances.
+         timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+         check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+     Examples:
+         Scenario 1 - Ideal case (target met quickly):
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+             If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+         Scenario 2 - Min instances met, target not reached:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+             If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+         Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+             wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+             If 2 instances are available after 5 seconds, function continues waiting for target_instances
+             until either 4 instances are found or 30 seconds have elapsed.
+
+         Scenario 4 - Timeout scenario:
+             wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+             If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+         Scenario 5 - Single instance job (early return):
+             wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+             The function returns without waiting because target_instances <= 1.

      Raises:
+         ValueError: If arguments are invalid
          TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
      """
-     if min_instances <= 1:
-         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+     if min_instances > target_instances:
+         raise ValueError(
+             f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+         )
+     if timeout < 0:
+         raise ValueError("Timeout must be greater than 0")
+     if check_interval < 0:
+         raise ValueError("Check interval must be greater than 0")
+
+     if target_instances <= 1:
+         logger.debug("Target instances is 1 or less, no need to wait for additional instances")
          return

+     if min_wait_time < 0:
+         # Automatically set min_wait_time based on the number of target instances
+         # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+         # target_instances = 1 => min_wait_time = 0
+         # target_instances = 2 => min_wait_time = 3
+         # target_instances = 4 => min_wait_time = 6
+         # target_instances = 8 => min_wait_time = 9
+         # target_instances = 32 => min_wait_time = 15
+         # target_instances = 50 => min_wait_time = 16.9
+         # target_instances = 100 => min_wait_time = 19.9
+         min_wait_time = min(3 * math.log2(target_instances), timeout / 10) # Clamp to timeout / 10
+
      # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
      from common_utils import common_util as mlrs_util

      start_time = time.time()
-     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
-     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
-     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+     current_interval = max(min(1, check_interval), 0.1) # Default 1s, minimum 0.1s
+     logger.debug(
+         "Waiting for instances to be ready "
+         "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+             min_instances, target_instances, timeout, check_interval
+         )
+     )

-     while time.time() - start_time < timeout:
+     while (elapsed := time.time() - start_time) < timeout:
          total_nodes = mlrs_util.get_num_ray_nodes()
-
-         if total_nodes >= min_instances:
-             elapsed = time.time() - start_time
+         if total_nodes >= target_instances:
+             # Best case scenario: target_instances are already available
+             logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+             return
+         elif total_nodes >= min_instances and elapsed >= min_wait_time:
+             # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
              logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
              return

          logger.debug(
-             f"Waiting for instances: {total_nodes}/{min_instances} available "
-             f"(elapsed: {time.time() - start_time:.1f}s)"
+             f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+             f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
          )
-         time.sleep(check_interval)
+         time.sleep(current_interval)
+         current_interval = min(current_interval * 2, check_interval) # Exponential backoff

      raise TimeoutError(
-         f"Timed out after {timeout}s waiting for {min_instances} instances, only "
-         f"{mlrs_util.get_num_ray_nodes()} available"
+         f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
      )

@@ -133,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
      original_argv = sys.argv
      sys.argv = [script_path, *script_args]

+     # Ensure payload directory is in sys.path for module imports
+     # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+     # but user scripts are in the payload directory and may import from each other
+     payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+     if payload_dir and payload_dir not in sys.path:
+         sys.path.insert(0, payload_dir)
+
      # Create a Snowpark session before running the script
      # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
      session = Session.builder.configs(SnowflakeLoginOptions()).create() # noqa: F841
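
Because the launcher now runs from the system directory while user code lives in the payload directory, that directory is pushed onto sys.path so sibling modules keep importing each other. A minimal sketch of the effect; the default path and module names are assumptions:

    import os
    import sys

    payload_dir = os.environ.get("MLRS_PAYLOAD_DIR", "/mnt/job_stage/app")  # default path is an assumption
    if payload_dir and payload_dir not in sys.path:
        sys.path.insert(0, payload_dir)

    # A user entrypoint at <payload_dir>/train.py can then keep doing e.g.:
    #   from src.utils import preprocess   # resolved against the payload directory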
@@ -179,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
      Raises:
          Exception: Re-raises any exception caught during script execution.
      """
+     # Ensure the output directory exists before trying to write result files.
+     output_dir = os.path.dirname(JOB_RESULT_PATH)
+     os.makedirs(output_dir, exist_ok=True)
+
      try:
          # Wait for minimum required instances if specified
          min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-         if min_instances_str and int(min_instances_str) > 1:
-             wait_for_min_instances(int(min_instances_str))
+         target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+         if target_instances_str and int(target_instances_str) > 1:
+             wait_for_instances(
+                 int(min_instances_str),
+                 int(target_instances_str),
+                 min_wait_time=MIN_WAIT_TIME,
+                 timeout=TIMEOUT,
+                 check_interval=CHECK_INTERVAL,
+             )

          # Log start marker for user script execution
          print(LOG_START_MSG) # noqa: T201
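
For reference, the automatic min_wait_time and the exponential backoff used by wait_for_instances can be reproduced in isolation. This standalone sketch only re-implements the two small formulas from the hunk above and does not talk to Ray:

    import math


    def default_min_wait_time(target_instances: int, timeout: float = 720.0) -> float:
        # min_wait_time = 3 * log2(target_instances), clamped to timeout / 10.
        return min(3 * math.log2(target_instances), timeout / 10)


    def backoff_intervals(check_interval: float = 10.0, steps: int = 6) -> list[float]:
        # Starts at min(1, check_interval), floored at 0.1s, and doubles up to check_interval.
        interval = max(min(1.0, check_interval), 0.1)
        out = []
        for _ in range(steps):
            out.append(interval)
            interval = min(interval * 2, check_interval)
        return out


    print(round(default_min_wait_time(4), 1))    # 6.0
    print(round(default_min_wait_time(100), 1))  # 19.9
    print(backoff_intervals())                   # [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]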