snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (35)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/jobs/__init__.py +4 -0
  4. snowflake/ml/jobs/_interop/__init__.py +0 -0
  5. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  6. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  7. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  8. snowflake/ml/jobs/_interop/legacy.py +225 -0
  9. snowflake/ml/jobs/_interop/protocols.py +471 -0
  10. snowflake/ml/jobs/_interop/results.py +51 -0
  11. snowflake/ml/jobs/_interop/utils.py +144 -0
  12. snowflake/ml/jobs/_utils/constants.py +4 -1
  13. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  14. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  16. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  17. snowflake/ml/jobs/_utils/types.py +10 -0
  18. snowflake/ml/jobs/job.py +168 -36
  19. snowflake/ml/jobs/manager.py +36 -38
  20. snowflake/ml/model/_client/model/model_version_impl.py +39 -7
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  22. snowflake/ml/model/_client/sql/model_version.py +3 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
  24. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  25. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  26. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  27. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  29. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  30. snowflake/ml/version.py +1 -1
  31. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
  32. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
  33. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  34. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_interop/results.py
@@ -0,0 +1,51 @@
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+
+ @dataclass(frozen=True)
+ class ExecutionResult:
+     """
+     A result of a job execution.
+
+     Args:
+         success: Whether the execution was successful.
+         value: The value of the execution.
+     """
+
+     success: bool
+     value: Any
+
+     def get_value(self, wrap_exceptions: bool = True) -> Any:
+         if not self.success:
+             assert isinstance(self.value, BaseException), "Unexpected non-exception value for failed result"
+             self._raise_exception(self.value, wrap_exceptions)
+         return self.value
+
+     def _raise_exception(self, exception: BaseException, wrap_exceptions: bool) -> None:
+         if wrap_exceptions:
+             raise RuntimeError(f"Job execution failed with error: {exception!r}") from exception
+         else:
+             raise exception
+
+
+ @dataclass(frozen=True)
+ class LoadedExecutionResult(ExecutionResult):
+     """
+     A result of a job execution that has been loaded from a file.
+     """
+
+     load_error: Optional[Exception] = None
+     result_metadata: Optional[dict[str, Any]] = None
+
+     def get_value(self, wrap_exceptions: bool = True) -> Any:
+         if not self.success:
+             # Raise the original exception if available, otherwise raise the load error
+             ex = self.value
+             if not isinstance(ex, BaseException):
+                 ex = RuntimeError(f"Unknown error {ex or ''}")
+                 ex.__cause__ = self.load_error
+             self._raise_exception(ex, wrap_exceptions)
+         else:
+             if self.load_error:
+                 raise ValueError("Job execution succeeded but result retrieval failed") from self.load_error
+         return self.value
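A minimal usage sketch of the new ExecutionResult type (illustrative only, not part of the package diff), based solely on the code added above:

    from snowflake.ml.jobs._interop.results import ExecutionResult

    ok = ExecutionResult(success=True, value=42)
    assert ok.get_value() == 42

    failed = ExecutionResult(success=False, value=ValueError("boom"))
    try:
        failed.get_value()  # wraps the original error in a RuntimeError by default
    except RuntimeError as e:
        assert isinstance(e.__cause__, ValueError)

    try:
        failed.get_value(wrap_exceptions=False)  # re-raises the original exception unchanged
    except ValueError:
        pass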
snowflake/ml/jobs/_interop/utils.py
@@ -0,0 +1,144 @@
+ import logging
+ import os
+ import traceback
+ from pathlib import PurePath
+ from typing import Any, Callable, Optional
+
+ import pydantic
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
+ from snowflake.ml.jobs._interop.dto_schema import (
+     ExceptionMetadata,
+     ResultDTO,
+     ResultMetadata,
+ )
+ from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
+ from snowflake.snowpark import exceptions as sp_exceptions
+
+ DEFAULT_CODEC = data_utils.JsonDtoCodec
+ DEFAULT_PROTOCOL = protocols.AutoProtocol()
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.CloudPickleProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
+     """
+     Save the result to a file.
+     """
+     result_dto = ResultDTO(
+         success=result.success,
+         value=result.value,
+     )
+
+     try:
+         # Try to encode result directly
+         payload = DEFAULT_CODEC.encode(result_dto)
+     except TypeError:
+         result_dto.value = None  # Remove raw value to avoid serialization error
+         result_dto.metadata = _get_metadata(result.value)  # Add metadata for client fallback on protocol mismatch
+         try:
+             path_dir = PurePath(path).parent.as_posix()
+             protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
+             result_dto.protocol = protocol_info
+
+         except Exception as e:
+             logger.warning(f"Error dumping result value: {repr(e)}")
+             result_dto.serialize_error = repr(e)
+
+         # Encode the modified result DTO
+         payload = DEFAULT_CODEC.encode(result_dto)
+
+     with data_utils.open_stream(path, "wb", session=session) as stream:
+         stream.write(payload)
+
+
+ def load_result(
+     path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
+ ) -> ExecutionResult:
+     """Load the result from a file on a Snowflake stage."""
+     try:
+         with data_utils.open_stream(path, "r", session=session) as stream:
+             # Load the DTO as a dict for easy fallback to legacy loading if necessary
+             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+     except UnicodeDecodeError:
+         # Path may be a legacy result file (cloudpickle)
+         # TODO: Re-use the stream
+         assert session is not None
+         return legacy.load_legacy_result(session, path)
+
+     try:
+         dto = ResultDTO.model_validate(dto_dict)
+     except pydantic.ValidationError as e:
+         if "success" in dto_dict:
+             assert session is not None
+             if path.endswith(".json"):
+                 path = os.path.splitext(path)[0] + ".pkl"
+             return legacy.load_legacy_result(session, path, result_json=dto_dict)
+         raise ValueError("Invalid result schema") from e
+
+     # Try loading data from file using the protocol info
+     result_value = None
+     data_load_error = None
+     if dto.protocol is not None:
+         try:
+             logger.debug(f"Loading result value with protocol {dto.protocol}")
+             result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
+         except sp_exceptions.SnowparkSQLException:
+             raise  # Data retrieval errors should be bubbled up
+         except Exception as e:
+             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+             data_load_error = e
+
+     # Wrap serialize_error in a TypeError
+     if dto.serialize_error:
+         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
+         if data_load_error:
+             data_load_error.__context__ = serialize_error
+         else:
+             data_load_error = serialize_error
+
+     # Prepare to assemble the final result
+     result_value = result_value if result_value is not None else dto.value
+     if not dto.success and result_value is None:
+         # Try to reconstruct exception from metadata if available
+         if isinstance(dto.metadata, ExceptionMetadata):
+             logger.debug(f"Reconstructing exception from metadata {dto.metadata}")
+             result_value = exception_utils.build_exception(
+                 type_str=dto.metadata.type,
+                 message=dto.metadata.message,
+                 traceback=dto.metadata.traceback,
+                 original_repr=dto.metadata.repr,
+             )
+
+         # Generate a generic error if we still don't have a value,
+         # attaching the data load error if any
+         if result_value is None:
+             result_value = exception_utils.RemoteError("Unknown remote error")
+             result_value.__cause__ = data_load_error
+
+     return LoadedExecutionResult(
+         success=dto.success,
+         value=result_value,
+         load_error=data_load_error,
+     )
+
+
+ def _get_metadata(value: Any) -> ResultMetadata:
+     type_name = f"{type(value).__module__}.{type(value).__name__}"
+     if isinstance(value, BaseException):
+         return ExceptionMetadata(
+             type=type_name,
+             repr=repr(value),
+             message=str(value),
+             traceback="".join(traceback.format_tb(value.__traceback__)),
+         )
+     return ResultMetadata(
+         type=type_name,
+         repr=repr(value),
+     )
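A hypothetical round trip through these new helpers (illustrative only; the stage path is a placeholder and a pre-configured Snowpark connection is assumed):

    from snowflake import snowpark
    from snowflake.ml.jobs._interop import utils as interop_utils
    from snowflake.ml.jobs._interop.results import ExecutionResult

    session = snowpark.Session.builder.getOrCreate()  # assumes connection parameters are already configured
    result_path = "@my_stage/output/mljob_result"     # placeholder stage path

    interop_utils.save_result(ExecutionResult(success=True, value={"rows": 10}), result_path, session=session)
    loaded = interop_utils.load_result(result_path, session=session)
    print(loaded.get_value())  # {'rows': 10}; raises instead if the stored result was a failure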
snowflake/ml/jobs/_utils/constants.py
@@ -12,6 +12,9 @@ PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
  RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
  MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
  TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
+ INSTANCES_MIN_WAIT_ENV_VAR = "MLRS_INSTANCES_MIN_WAIT"
+ INSTANCES_TIMEOUT_ENV_VAR = "MLRS_INSTANCES_TIMEOUT"
+ INSTANCES_CHECK_INTERVAL_ENV_VAR = "MLRS_INSTANCES_CHECK_INTERVAL"
  RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"

  # Stage mount paths
@@ -19,7 +22,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
  APP_STAGE_SUBPATH = "app"
  SYSTEM_STAGE_SUBPATH = "system"
  OUTPUT_STAGE_SUBPATH = "output"
- RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
+ RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result"

  # Default container image information
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
snowflake/ml/jobs/_utils/feature_flags.py
@@ -1,16 +1,48 @@
  import os
  from enum import Enum
+ from typing import Optional
+
+
+ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
+     """Parse a boolean value from an environment variable string.
+
+     Args:
+         value: The environment variable value to parse (may be None).
+         default: The default value to return if the value is None or unrecognized.
+
+     Returns:
+         True if the value is a truthy string (true, 1, yes, on - case insensitive),
+         False if the value is a falsy string (false, 0, no, off - case insensitive),
+         or the default value if the value is None or unrecognized.
+     """
+     if value is None:
+         return default
+
+     normalized_value = value.strip().lower()
+     if normalized_value in ("true", "1", "yes", "on"):
+         return True
+     elif normalized_value in ("false", "0", "no", "off"):
+         return False
+     else:
+         # For unrecognized values, return the default
+         return default


  class FeatureFlags(Enum):
      USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
-     ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS"
+     ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
+
+     def is_enabled(self, default: bool = False) -> bool:
+         """Check if the feature flag is enabled.

-     def is_enabled(self) -> bool:
-         return os.getenv(self.value, "false").lower() == "true"
+         Args:
+             default: The default value to return if the environment variable is not set.

-     def is_disabled(self) -> bool:
-         return not self.is_enabled()
+         Returns:
+             True if the environment variable is set to a truthy value,
+             False if set to a falsy value, or the default value if not set.
+         """
+         return parse_bool_env_value(os.getenv(self.value), default)

      def __str__(self) -> str:
          return self.value
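Illustrative behavior of the boolean parsing and flag API added above (note that the old is_disabled() helper is removed in 1.17.0); this sketch follows directly from the code shown and is not part of the diff:

    import os
    from snowflake.ml.jobs._utils.feature_flags import FeatureFlags, parse_bool_env_value

    assert parse_bool_env_value("YES") is True           # truthy values are case-insensitive
    assert parse_bool_env_value("off") is False          # falsy values
    assert parse_bool_env_value(None, default=True)      # unset -> default
    assert parse_bool_env_value("maybe") is False        # unrecognized -> default (False)

    os.environ["MLRS_ENABLE_RUNTIME_VERSIONS"] = "1"
    assert FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled() is True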
snowflake/ml/jobs/_utils/payload_utils.py
@@ -268,7 +268,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
      # can't handle directories. Reduce the number of PUT operations by using
      # wildcard patterns to batch upload files with the same extension.
      upload_path_patterns = set()
-     for p in source_path.resolve().rglob("*"):
+     for p in source_path.rglob("*"):
          if p.is_dir():
              continue
          if p.name.startswith("."):
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -9,19 +9,23 @@ import runpy
  import sys
  import time
  import traceback
- import warnings
- from pathlib import Path
  from typing import Any, Optional

- import cloudpickle
-
- from snowflake.ml.jobs._utils import constants
- from snowflake.snowpark import Session
-
- try:
-     from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
- except ImportError:
-     from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
+ # Ensure payload directory is in sys.path for module imports before importing other modules
+ # This is needed to support relative imports in user scripts and to allow overriding
+ # modules using modules in the payload directory
+ # TODO: Inject the environment variable names at job submission time
+ STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
+ JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
+ PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
+ if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
+     PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH)
+ if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path:
+     sys.path.insert(0, PAYLOAD_PATH)
+
+ # Imports below must come after sys.path modification to support module overrides
+ import snowflake.ml.jobs._utils.constants  # noqa: E402
+ import snowflake.snowpark  # noqa: E402

  # Configure logging
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -33,48 +37,74 @@ logger = logging.getLogger(__name__)
  # not have the latest version of the code
  # Log start and end messages
  LOG_START_MSG = getattr(
-     constants,
+     snowflake.ml.jobs._utils.constants,
      "LOG_START_MSG",
      "--------------------------------\nML job started\n--------------------------------",
  )
  LOG_END_MSG = getattr(
-     constants,
+     snowflake.ml.jobs._utils.constants,
      "LOG_END_MSG",
      "--------------------------------\nML job finished\n--------------------------------",
  )
+ MIN_INSTANCES_ENV_VAR = getattr(
+     snowflake.ml.jobs._utils.constants,
+     "MIN_INSTANCES_ENV_VAR",
+     "MLRS_MIN_INSTANCES",
+ )
+ TARGET_INSTANCES_ENV_VAR = getattr(
+     snowflake.ml.jobs._utils.constants,
+     "TARGET_INSTANCES_ENV_VAR",
+     "SNOWFLAKE_JOBS_COUNT",
+ )
+ INSTANCES_MIN_WAIT_ENV_VAR = getattr(
+     snowflake.ml.jobs._utils.constants,
+     "INSTANCES_MIN_WAIT_ENV_VAR",
+     "MLRS_INSTANCES_MIN_WAIT",
+ )
+ INSTANCES_TIMEOUT_ENV_VAR = getattr(
+     snowflake.ml.jobs._utils.constants,
+     "INSTANCES_TIMEOUT_ENV_VAR",
+     "MLRS_INSTANCES_TIMEOUT",
+ )
+ INSTANCES_CHECK_INTERVAL_ENV_VAR = getattr(
+     snowflake.ml.jobs._utils.constants,
+     "INSTANCES_CHECK_INTERVAL_ENV_VAR",
+     "MLRS_INSTANCES_CHECK_INTERVAL",
+ )

- # min_instances environment variable name
- MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
- TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
-
- # Fallbacks in case of SnowML version mismatch
- STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
- RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
- PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")

  # Constants for the wait_for_instances function
- MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
- TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
- CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds
+ MIN_INSTANCES = int(os.environ.get(MIN_INSTANCES_ENV_VAR) or "1")
+ TARGET_INSTANCES = int(os.environ.get(TARGET_INSTANCES_ENV_VAR) or MIN_INSTANCES)
+ MIN_WAIT_TIME = float(os.getenv(INSTANCES_MIN_WAIT_ENV_VAR) or -1)  # seconds
+ TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720)  # seconds
+ CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10)  # seconds

- STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
- JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")

+ def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None:
+     from snowflake.ml.jobs._interop import (
+         results as interop_result,
+         utils as interop_utils,
+     )
+
+     result_obj = interop_result.ExecutionResult(success=not is_error, value=value)
+     interop_utils.save_result(result_obj, path)

- try:
-     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
- except ImportError:
+
+ def save_mljob_result_v1(value: Any, is_error: bool, path: str) -> None:
      from dataclasses import dataclass

+     import cloudpickle
+
+     # Directly in-line the ExecutionResult class since the legacy type
+     # instead of attempting to import the to-be-deprecated
+     # snowflake.ml.jobs._utils.interop module
+     # Eventually, this entire function will be removed in favor of v2
      @dataclass(frozen=True)
-     class ExecutionResult:  # type: ignore[no-redef]
+     class ExecutionResult:
          result: Optional[Any] = None
          exception: Optional[BaseException] = None

-         @property
-         def success(self) -> bool:
-             return self.exception is None
-
          def to_dict(self) -> dict[str, Any]:
              """Return the serializable dictionary."""
              if isinstance(self.exception, BaseException):
@@ -91,14 +121,45 @@ except ImportError:
                  "result": self.result,
              }

+     # Create a custom JSON encoder that converts non-serializable types to strings
+     class SimpleJSONEncoder(json.JSONEncoder):
+         def default(self, obj: Any) -> Any:
+             try:
+                 return super().default(obj)
+             except TypeError:
+                 return f"Unserializable object: {repr(obj)}"
+
+     result_obj = ExecutionResult(result=None if is_error else value, exception=value if is_error else None)
+     result_dict = result_obj.to_dict()
+     try:
+         # Serialize result using cloudpickle
+         result_pickle_path = path
+         with open(result_pickle_path, "wb") as f:
+             cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
+     except Exception as pkl_exc:
+         logger.warning(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}")

- # Create a custom JSON encoder that converts non-serializable types to strings
- class SimpleJSONEncoder(json.JSONEncoder):
-     def default(self, obj: Any) -> Any:
-         try:
-             return super().default(obj)
-         except TypeError:
-             return f"Unserializable object: {repr(obj)}"
+     try:
+         # Serialize result to JSON as fallback path in case of cross version incompatibility
+         result_json_path = os.path.splitext(path)[0] + ".json"
+         with open(result_json_path, "w") as f:
+             json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
+     except Exception as json_exc:
+         logger.warning(f"Failed to serialize JSON result to {result_json_path}: {json_exc}")
+
+
+ def save_mljob_result(result_obj: Any, is_error: bool, path: str) -> None:
+     """Saves the result or error message to a file in the stage mount path.
+
+     Args:
+         result_obj: The result object to save, either the return value or the exception.
+         is_error: Whether the result_obj is a raised exception.
+         path: The file path to save the result to.
+     """
+     try:
+         save_mljob_result_v2(result_obj, is_error, path)
+     except ImportError:
+         save_mljob_result_v1(result_obj, is_error, path)


  def wait_for_instances(
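A small illustration of the fallback defined above (values and paths are placeholders): save_mljob_result prefers the new _interop-based writer and drops back to the legacy pickle/JSON writer when the _interop package is not importable, e.g. on an older runtime image.

    result_path = "/mnt/job_stage/output/mljob_result.pkl"   # default stage-mounted result path
    save_mljob_result({"accuracy": 0.93}, is_error=False, path=result_path)
    # The v1 fallback also writes a JSON sibling next to the pickle:
    # /mnt/job_stage/output/mljob_result.json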
@@ -225,20 +286,10 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
      original_argv = sys.argv
      sys.argv = [script_path, *script_args]

-     # Ensure payload directory is in sys.path for module imports
-     # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
-     # but user scripts are in the payload directory and may import from each other
-     payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
-     if payload_dir and not os.path.isabs(payload_dir):
-         payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
-     if payload_dir and payload_dir not in sys.path:
-         sys.path.insert(0, payload_dir)
-
      try:
-
          if main_func:
              # Use importlib for scripts with a main function defined
-             module_name = Path(script_path).stem
+             module_name = os.path.splitext(os.path.basename(script_path))[0]
              spec = importlib.util.spec_from_file_location(module_name, script_path)
              assert spec is not None
              assert spec.loader is not None
@@ -262,7 +313,7 @@
          sys.argv = original_argv


- def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
+ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
      """Executes a Python script and serializes the result to JOB_RESULT_PATH.

      Args:
@@ -271,55 +322,53 @@
          script_main_func (str, optional): The name of the function to call in the script (if any).

      Returns:
-         ExecutionResult: Object containing execution results.
+         Any: The result of the script execution.

      Raises:
          Exception: Re-raises any exception caught during script execution.
      """
-     # Ensure the output directory exists before trying to write result files.
-     result_abs_path = (
-         JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
-     )
-     output_dir = os.path.dirname(result_abs_path)
-     os.makedirs(output_dir, exist_ok=True)
+     try:
+         from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+     except ImportError:
+         from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

+     # Initialize Ray if available
      try:
          import ray

          ray.init(address="auto")
      except ModuleNotFoundError:
-         warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)
+         logger.debug("Ray is not installed, skipping Ray initialization")

      # Create a Snowpark session before starting
      # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
      config = SnowflakeLoginOptions()
      config["client_session_keep_alive"] = "True"
-     session = Session.builder.configs(config).create()  # noqa: F841
+     session = snowflake.snowpark.Session.builder.configs(config).create()  # noqa: F841

+     execution_result_is_error = False
+     execution_result_value = None
      try:
-         # Wait for minimum required instances if specified
-         min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-         target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
-         if target_instances_str and int(target_instances_str) > 1:
-             wait_for_instances(
-                 int(min_instances_str),
-                 int(target_instances_str),
-                 min_wait_time=MIN_WAIT_TIME,
-                 timeout=TIMEOUT,
-                 check_interval=CHECK_INTERVAL,
-             )
-
-         # Log start marker for user script execution
+         # Wait for minimum required instances before starting user script execution
+         wait_for_instances(
+             MIN_INSTANCES,
+             TARGET_INSTANCES,
+             min_wait_time=MIN_WAIT_TIME,
+             timeout=TIMEOUT,
+             check_interval=CHECK_INTERVAL,
+         )
+
+         # Log start marker before starting user script execution
          print(LOG_START_MSG)  # noqa: T201

-         # Run the script with the specified arguments
-         result = run_script(script_path, *script_args, main_func=script_main_func)
+         # Run the user script
+         execution_result_value = run_script(script_path, *script_args, main_func=script_main_func)

          # Log end marker for user script execution
          print(LOG_END_MSG)  # noqa: T201

-         result_obj = ExecutionResult(result=result)
-         return result_obj
+         return execution_result_value
+

      except Exception as e:
          tb = e.__traceback__
@@ -328,35 +377,23 @@
              tb = tb.tb_next
          cleaned_ex = copy.copy(e)  # Need to create a mutable copy of exception to set __traceback__
          cleaned_ex = cleaned_ex.with_traceback(tb)
-         result_obj = ExecutionResult(exception=cleaned_ex)
+         execution_result_value = cleaned_ex
+         execution_result_is_error = True
          raise
      finally:
-         result_dict = result_obj.to_dict()
-         try:
-             # Serialize result using cloudpickle
-             result_pickle_path = result_abs_path
-             with open(result_pickle_path, "wb") as f:
-                 cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
-         except Exception as pkl_exc:
-             warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1)
-
-         try:
-             # Serialize result to JSON as fallback path in case of cross version incompatibility
-             # TODO: Manually convert non-serializable types to strings
-             result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
-             with open(result_json_path, "w") as f:
-                 json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
-         except Exception as json_exc:
-             warnings.warn(
-                 f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
-             )
-
-         # Close the session after serializing the result
+         # Ensure the output directory exists before trying to write result files.
+         result_abs_path = (
+             JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
+         )
+         output_dir = os.path.dirname(result_abs_path)
+         os.makedirs(output_dir, exist_ok=True)
+
+         # Save the result before closing the session
+         save_mljob_result(execution_result_value, execution_result_is_error, result_abs_path)
          session.close()


  if __name__ == "__main__":
-     # Parse command line arguments
      parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
      parser.add_argument("script_path", help="Path to the Python script to execute")
      parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
snowflake/ml/jobs/_utils/spec_utils.py
@@ -104,7 +104,7 @@ def _get_image_spec(
              image_tag = runtime_environment
          else:
              container_image = runtime_environment
-     elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
+     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
          container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]

      container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
@@ -266,6 +266,7 @@ def generate_service_spec(
          {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
          {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
          {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+         {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
          {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
          {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
          {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
snowflake/ml/jobs/_utils/types.py
@@ -11,6 +11,7 @@ JOB_STATUS = Literal[
      "CANCELLING",
      "CANCELLED",
      "INTERNAL_ERROR",
+     "DELETED",
  ]


@@ -106,3 +107,12 @@ class ImageSpec:
      resource_requests: ComputeResources
      resource_limits: ComputeResources
      container_image: str
+
+
+ @dataclass(frozen=True)
+ class ServiceInfo:
+     database_name: str
+     schema_name: str
+     status: str
+     compute_pool: str
+     target_instances: int