snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +36 -38
- snowflake/ml/model/_client/model/model_version_impl.py +39 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
aerial
|
|
1
2
|
afraid
|
|
2
3
|
ancient
|
|
3
4
|
angry
|
|
@@ -26,7 +27,6 @@ dull
|
|
|
26
27
|
empty
|
|
27
28
|
evil
|
|
28
29
|
fast
|
|
29
|
-
fat
|
|
30
30
|
fluffy
|
|
31
31
|
foolish
|
|
32
32
|
fresh
|
|
@@ -57,10 +57,10 @@ lovely
|
|
|
57
57
|
lucky
|
|
58
58
|
massive
|
|
59
59
|
mean
|
|
60
|
+
metallic
|
|
60
61
|
mighty
|
|
61
62
|
modern
|
|
62
63
|
moody
|
|
63
|
-
nasty
|
|
64
64
|
neat
|
|
65
65
|
nervous
|
|
66
66
|
new
|
|
@@ -85,7 +85,6 @@ rotten
|
|
|
85
85
|
rude
|
|
86
86
|
selfish
|
|
87
87
|
serious
|
|
88
|
-
shaggy
|
|
89
88
|
sharp
|
|
90
89
|
short
|
|
91
90
|
shy
|
|
@@ -96,14 +95,15 @@ slippery
|
|
|
96
95
|
smart
|
|
97
96
|
smooth
|
|
98
97
|
soft
|
|
98
|
+
solid
|
|
99
99
|
sour
|
|
100
100
|
spicy
|
|
101
101
|
splendid
|
|
102
102
|
spotty
|
|
103
|
+
squishy
|
|
103
104
|
stale
|
|
104
105
|
strange
|
|
105
106
|
strong
|
|
106
|
-
stupid
|
|
107
107
|
sweet
|
|
108
108
|
swift
|
|
109
109
|
tall
|
|
@@ -116,7 +116,6 @@ tidy
|
|
|
116
116
|
tiny
|
|
117
117
|
tough
|
|
118
118
|
tricky
|
|
119
|
-
ugly
|
|
120
119
|
warm
|
|
121
120
|
weak
|
|
122
121
|
wet
|
|
@@ -124,5 +123,6 @@ wicked
|
|
|
124
123
|
wise
|
|
125
124
|
witty
|
|
126
125
|
wonderful
|
|
126
|
+
wooden
|
|
127
127
|
yellow
|
|
128
128
|
young
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
anaconda
|
|
2
2
|
ant
|
|
3
|
-
ape
|
|
4
|
-
baboon
|
|
5
3
|
badger
|
|
6
4
|
bat
|
|
7
5
|
bear
|
|
6
|
+
beetle
|
|
8
7
|
bird
|
|
9
8
|
bobcat
|
|
10
9
|
bulldog
|
|
@@ -73,7 +72,6 @@ lobster
|
|
|
73
72
|
mayfly
|
|
74
73
|
mamba
|
|
75
74
|
mole
|
|
76
|
-
monkey
|
|
77
75
|
moose
|
|
78
76
|
moth
|
|
79
77
|
mouse
|
|
@@ -114,6 +112,7 @@ swan
|
|
|
114
112
|
termite
|
|
115
113
|
tiger
|
|
116
114
|
treefrog
|
|
115
|
+
tuna
|
|
117
116
|
turkey
|
|
118
117
|
turtle
|
|
119
118
|
vampirebat
|
|
@@ -126,3 +125,4 @@ worm
|
|
|
126
125
|
yak
|
|
127
126
|
yeti
|
|
128
127
|
zebra
|
|
128
|
+
zebrafish
|
snowflake/ml/jobs/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from snowflake.ml.jobs._interop.exception_utils import install_exception_display_hooks
|
|
1
2
|
from snowflake.ml.jobs._utils.types import JOB_STATUS
|
|
2
3
|
from snowflake.ml.jobs.decorators import remote
|
|
3
4
|
from snowflake.ml.jobs.job import MLJob
|
|
@@ -10,6 +11,9 @@ from snowflake.ml.jobs.manager import (
|
|
|
10
11
|
submit_from_stage,
|
|
11
12
|
)
|
|
12
13
|
|
|
14
|
+
# Initialize exception display hooks for remote job error handling
|
|
15
|
+
install_exception_display_hooks()
|
|
16
|
+
|
|
13
17
|
__all__ = [
|
|
14
18
|
"remote",
|
|
15
19
|
"submit_file",
|
|
File without changes
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Literal, Optional, Protocol, Union, cast, overload
|
|
4
|
+
|
|
5
|
+
from snowflake import snowpark
|
|
6
|
+
from snowflake.ml.jobs._interop import dto_schema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class StageFileWriter(io.IOBase):
    """
    Write-only, in-memory file object that uploads its contents to a Snowflake stage.

    Writes are collected in an internal ``BytesIO``. When the writer is closed the
    buffered bytes are uploaded with ``session.file.put_stream(buffer, path)``.
    Used as a context manager, the upload is skipped if the ``with`` body raised.
    """

    def __init__(self, session: "snowpark.Session", path: str) -> None:
        """
        Args:
            session: Session whose ``file.put_stream`` performs the upload.
            path: Destination path handed unchanged to ``put_stream``.
        """
        self._session = session
        self._path = path
        self._buffer = io.BytesIO()
        self._closed = False
        self._exception_occurred = False

    def write(self, data: Union[bytes, bytearray]) -> int:
        """Append ``data`` to the internal buffer and return the number of bytes written.

        Raises:
            ValueError: If the writer has already been closed.
        """
        if self._closed:
            raise ValueError("I/O operation on closed file")
        return self._buffer.write(data)

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Move the position inside the internal buffer (backs up ``seekable()``)."""
        return self._buffer.seek(offset, whence)

    def tell(self) -> int:
        """Return the current position inside the internal buffer."""
        return self._buffer.tell()

    def close(self, write_contents: bool = True) -> None:
        """Close the writer, uploading the buffered bytes unless told otherwise.

        Calling ``close`` on an already closed writer is a no-op. The writer is
        always marked closed and its buffer released, even when the upload raises,
        so a failed upload is never silently retried by a later ``close()``.

        Args:
            write_contents: When False, discard the buffer without uploading.
        """
        if self._closed:
            return
        try:
            # Measure the buffer by its size rather than the current position, so a
            # caller that used seek() before closing still gets its data uploaded.
            size = self._buffer.seek(0, io.SEEK_END)
            if write_contents and size > 0:
                self._buffer.seek(0)
                self._session.file.put_stream(self._buffer, self._path)
        finally:
            self._buffer.close()
            self._closed = True

    def __enter__(self) -> "StageFileWriter":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Do not upload partial output when the with-body failed.
        self._exception_occurred = exc_type is not None
        self.close(write_contents=not self._exception_occurred)

    @property
    def closed(self) -> bool:
        """Whether ``close()`` has completed for this writer."""
        return self._closed

    def writable(self) -> bool:
        return not self._closed

    def readable(self) -> bool:
        return False

    def seekable(self) -> bool:
        return not self._closed
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _is_stage_path(path: str) -> bool:
|
|
60
|
+
return path.startswith("@") or path.startswith("snow://")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def open_stream(path: str, mode: str = "rb", session: Optional[snowpark.Session] = None) -> io.IOBase:
    """
    Open ``path`` as a stream, whether it is a local file or a stage location.

    Args:
        path: Local file path, or a stage path as recognised by ``_is_stage_path``.
        mode: Open mode. Local paths pass it straight to ``open``. For stage paths
            a mode containing ``"r"`` reads and one containing ``"w"`` writes.
        session: Session used for stage access; required for stage paths.

    Returns:
        The local file object, the stream returned by ``session.file.get_stream``,
        or a ``StageFileWriter``.

    Raises:
        ValueError: If a stage path is given without a session, or with a mode
            that contains neither ``"r"`` nor ``"w"``.
    """
    # Local files need no session: delegate to the builtin and return early.
    if not _is_stage_path(path):
        local_file: io.IOBase = open(path, mode)  # type: ignore[assignment]
        return local_file

    if session is None:
        raise ValueError("Session is required when opening a stage path")

    # "r" is checked before "w", so a mode containing both opens for reading.
    if "r" in mode:
        reader: io.IOBase = session.file.get_stream(path)  # type: ignore[assignment]
        return reader
    if "w" in mode:
        return StageFileWriter(session, path)
    raise ValueError(f"Unsupported mode '{mode}' for stage path")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class DtoCodec(Protocol):
    """Structural interface for codecs that convert a ``ResultDTO`` to and from bytes."""

    @overload
    @staticmethod
    def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
        ...

    @overload
    @staticmethod
    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
        ...

    @staticmethod
    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
        """Read an encoded result from ``stream``.

        Per the overloads, ``as_dict=True`` yields a plain dict and the default
        yields a ``ResultDTO``.
        """
        ...

    @staticmethod
    def encode(dto: dto_schema.ResultDTO) -> bytes:
        """Serialize ``dto`` to bytes."""
        ...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class JsonDtoCodec(DtoCodec):
    """``DtoCodec`` implementation that stores the DTO as UTF-8 encoded JSON."""

    @overload
    @staticmethod
    def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
        ...

    @overload
    @staticmethod
    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
        ...

    @staticmethod
    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
        """Parse JSON from ``stream``.

        Args:
            stream: Stream positioned at the start of a JSON document.
            as_dict: When True, return the parsed dict without validating it.

        Returns:
            The raw dict when ``as_dict`` is True, otherwise the result of
            ``ResultDTO.model_validate`` on the parsed data.
        """
        data = cast(dict[str, Any], json.load(stream))
        if as_dict:
            return data
        return dto_schema.ResultDTO.model_validate(data)

    @staticmethod
    def encode(dto: dto_schema.ResultDTO) -> bytes:
        """Serialize ``dto`` to UTF-8 JSON bytes.

        The ``value`` field is kept out of ``model_dump()`` and inserted into the
        dumped dict unchanged, so it must be JSON-serializable as is.

        Args:
            dto: The result to serialize. It is left unmodified on return.

        Returns:
            The JSON document encoded as UTF-8.
        """
        # Detach the value only for the duration of model_dump() so that
        # model_dump() is not applied to it.
        result_value = dto.value
        dto.value = None
        try:
            result_dict = dto.model_dump()
        finally:
            # Always give the caller's DTO its value back, even if dumping fails.
            dto.value = result_value
        result_dict["value"] = result_value
        return json.dumps(result_dict).encode("utf-8")
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, model_validator
|
|
4
|
+
from typing_extensions import NotRequired, TypedDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BinaryManifest(TypedDict):
    """
    Manifest schema for a binary payload.

    Contains one of: path, bytes, or base64 for the serialized data.
    Every key is declared ``NotRequired``.
    """

    # Path to the file holding the serialized data.
    path: NotRequired[str]
    # In-line byte string (not supported with JSON codec).
    bytes: NotRequired[bytes]
    # Base64 encoded string.
    base64: NotRequired[str]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ParquetManifest(TypedDict):
    """Manifest schema for results stored as parquet files."""

    # File paths of the parquet files.
    paths: list[str]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Union type for all manifest types, including catch-all dict[str, Any] for backward compatibility
|
|
25
|
+
PayloadManifest = Union[BinaryManifest, ParquetManifest, dict[str, Any]]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ProtocolInfo(BaseModel):
    """
    Describes the protocol used to serialize a result, plus the result's manifest.
    """

    name: str
    version: Optional[str] = None
    metadata: Optional[dict[str, str]] = None
    manifest: Optional[PayloadManifest] = None

    def __str__(self) -> str:
        """Return ``name``, or ``name-version`` when a non-empty version is set."""
        return f"{self.name}-{self.version}" if self.version else self.name

    def with_manifest(self, manifest: PayloadManifest) -> "ProtocolInfo":
        """
        Build a new ProtocolInfo with this instance's name, version and metadata
        and the given ``manifest``; the current instance is not modified.
        """
        fields = {
            "name": self.name,
            "version": self.version,
            "metadata": self.metadata,
            "manifest": manifest,
        }
        return ProtocolInfo(**fields)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ResultMetadata(BaseModel):
    """
    Descriptive metadata attached to a result.
    """

    # Type of the result, as a string.
    type: str
    # repr-style string for the result.
    repr: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ExceptionMetadata(ResultMetadata):
    """
    Result metadata for an exception: adds a message and a traceback string.
    """

    # Exception message.
    message: str
    # Traceback, carried as a string.
    traceback: str
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ResultDTO(BaseModel):
    """
    A JSON representation of an execution result.

    Args:
        success: Whether the execution was successful.
        value: The value of the execution or the exception if the execution failed.
        protocol: The protocol used to serialize the result.
        metadata: The metadata of the result.
        serialize_error: Optional error string; defaults to None.
    """

    success: bool
    value: Optional[Any] = None
    protocol: Optional[ProtocolInfo] = None
    # The more specific ExceptionMetadata is listed first. Input carrying
    # "message"/"traceback" also satisfies the base ResultMetadata, so a validator
    # that resolves ties left-to-right would otherwise pick the base class and
    # drop those fields. Input without them still falls through to ResultMetadata.
    metadata: Optional[Union[ExceptionMetadata, ResultMetadata]] = None
    serialize_error: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_fields(cls, data: Any) -> Any:
        """Ensure at least one of value, protocol, or metadata keys is specified.

        Only dict input is checked; any other input is returned untouched.

        Raises:
            ValueError: If ``data`` is a dict containing none of the three keys.
        """
        if isinstance(data, dict):
            required_fields = {"value", "protocol", "metadata"}
            if not any(field in data for field in required_fields):
                raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
        return data
|
|
@@ -1,19 +1,12 @@
|
|
|
1
1
|
import builtins
|
|
2
2
|
import functools
|
|
3
3
|
import importlib
|
|
4
|
-
import json
|
|
5
|
-
import os
|
|
6
|
-
import pickle
|
|
7
4
|
import re
|
|
8
5
|
import sys
|
|
9
6
|
import traceback
|
|
10
7
|
from collections import namedtuple
|
|
11
|
-
from dataclasses import dataclass
|
|
12
8
|
from types import TracebackType
|
|
13
|
-
from typing import Any, Callable, Optional,
|
|
14
|
-
|
|
15
|
-
from snowflake import snowpark
|
|
16
|
-
from snowflake.snowpark import exceptions as sp_exceptions
|
|
9
|
+
from typing import Any, Callable, Optional, cast
|
|
17
10
|
|
|
18
11
|
_TRACEBACK_ENTRY_PATTERN = re.compile(
|
|
19
12
|
r'File "(?P<filename>[^"]+)", line (?P<lineno>\d+), in (?P<name>[^\n]+)(?:\n(?!^\s*File)^\s*(?P<line>[^\n]+))?\n',
|
|
@@ -21,175 +14,46 @@ _TRACEBACK_ENTRY_PATTERN = re.compile(
|
|
|
21
14
|
)
|
|
22
15
|
_REMOTE_ERROR_ATTR_NAME = "_remote_error"
|
|
23
16
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@dataclass(frozen=True)
|
|
28
|
-
class ExecutionResult:
|
|
29
|
-
result: Any = None
|
|
30
|
-
exception: Optional[BaseException] = None
|
|
31
|
-
|
|
32
|
-
@property
|
|
33
|
-
def success(self) -> bool:
|
|
34
|
-
return self.exception is None
|
|
35
|
-
|
|
36
|
-
def to_dict(self) -> dict[str, Any]:
|
|
37
|
-
"""Return the serializable dictionary."""
|
|
38
|
-
if isinstance(self.exception, BaseException):
|
|
39
|
-
exc_type = type(self.exception)
|
|
40
|
-
return {
|
|
41
|
-
"success": False,
|
|
42
|
-
"exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
|
|
43
|
-
"exc_value": self.exception,
|
|
44
|
-
"exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
|
|
45
|
-
}
|
|
46
|
-
return {
|
|
47
|
-
"success": True,
|
|
48
|
-
"result_type": type(self.result).__qualname__,
|
|
49
|
-
"result": self.result,
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
@classmethod
|
|
53
|
-
def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
|
|
54
|
-
if not isinstance(result_dict.get("success"), bool):
|
|
55
|
-
raise ValueError("Invalid result dictionary")
|
|
56
|
-
|
|
57
|
-
if result_dict["success"]:
|
|
58
|
-
# Load successful result
|
|
59
|
-
return cls(result=result_dict.get("result"))
|
|
60
|
-
|
|
61
|
-
# Load exception
|
|
62
|
-
exc_type = result_dict.get("exc_type", "RuntimeError")
|
|
63
|
-
exc_value = result_dict.get("exc_value", "Unknown error")
|
|
64
|
-
exc_tb = result_dict.get("exc_tb", "")
|
|
65
|
-
return cls(exception=load_exception(exc_type, exc_value, exc_tb))
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult:
|
|
69
|
-
"""
|
|
70
|
-
Fetch the serialized result from the specified path.
|
|
17
|
+
RemoteErrorInfo = namedtuple("RemoteErrorInfo", ["exc_type", "exc_msg", "exc_tb"])
|
|
71
18
|
|
|
72
|
-
Args:
|
|
73
|
-
session: Snowpark Session to use for file operations.
|
|
74
|
-
result_path: The path to the serialized result file.
|
|
75
19
|
|
|
76
|
-
|
|
77
|
-
|
|
20
|
+
class RemoteError(RuntimeError):
|
|
21
|
+
"""Base exception for errors from remote execution environment which could not be reconstructed locally."""
|
|
78
22
|
|
|
79
|
-
Raises:
|
|
80
|
-
RuntimeError: If both pickle and JSON result retrieval fail.
|
|
81
|
-
"""
|
|
82
|
-
try:
|
|
83
|
-
# TODO: Check if file exists
|
|
84
|
-
with session.file.get_stream(result_path) as result_stream:
|
|
85
|
-
return ExecutionResult.from_dict(pickle.load(result_stream))
|
|
86
|
-
except (
|
|
87
|
-
sp_exceptions.SnowparkSQLException,
|
|
88
|
-
pickle.UnpicklingError,
|
|
89
|
-
TypeError,
|
|
90
|
-
ImportError,
|
|
91
|
-
AttributeError,
|
|
92
|
-
MemoryError,
|
|
93
|
-
) as pickle_error:
|
|
94
|
-
# Fall back to JSON result if loading pickled result fails for any reason
|
|
95
|
-
try:
|
|
96
|
-
result_json_path = os.path.splitext(result_path)[0] + ".json"
|
|
97
|
-
with session.file.get_stream(result_json_path) as result_stream:
|
|
98
|
-
return ExecutionResult.from_dict(json.load(result_stream))
|
|
99
|
-
except Exception as json_error:
|
|
100
|
-
# Both pickle and JSON failed - provide helpful error message
|
|
101
|
-
raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
|
|
105
|
-
"""Create helpful error messages for common result retrieval failures."""
|
|
106
|
-
|
|
107
|
-
# Package import issues
|
|
108
|
-
if isinstance(error, ImportError):
|
|
109
|
-
return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
|
|
110
|
-
|
|
111
|
-
# Package versions differ between runtime and local environment
|
|
112
|
-
if isinstance(error, AttributeError):
|
|
113
|
-
return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
|
|
114
|
-
|
|
115
|
-
# Serialization issues
|
|
116
|
-
if isinstance(error, TypeError):
|
|
117
|
-
return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
|
|
118
|
-
|
|
119
|
-
# Python version pickling incompatibility
|
|
120
|
-
if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
|
|
121
|
-
# TODO: Update this once we support different Python versions
|
|
122
|
-
client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
|
|
123
|
-
runtime_version = "Python 3.10"
|
|
124
|
-
return (
|
|
125
|
-
f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
|
|
126
|
-
f"local environment using Python {client_version}. Error: {str(error)}"
|
|
127
|
-
)
|
|
128
23
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
|
|
141
|
-
|
|
142
|
-
# Generic fallback
|
|
143
|
-
base_message = f"Failed to retrieve job result: {str(error)}"
|
|
144
|
-
if json_error:
|
|
145
|
-
base_message += f" (JSON fallback also failed: {str(json_error)})"
|
|
146
|
-
return base_message
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
|
|
150
|
-
"""
|
|
151
|
-
Create an exception with a string-formatted traceback.
|
|
152
|
-
|
|
153
|
-
When this exception is raised and not caught, it will display the original traceback.
|
|
154
|
-
When caught, it behaves like a regular exception without showing the traceback.
|
|
155
|
-
|
|
156
|
-
Args:
|
|
157
|
-
exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
|
|
158
|
-
exc_value: The deserialized exception value or exception string (i.e. message)
|
|
159
|
-
exc_tb: String representation of the traceback
|
|
24
|
+
def build_exception(type_str: str, message: str, traceback: str, original_repr: Optional[str] = None) -> BaseException:
    """
    Create an exception object from serialized metadata and tag it with remote error info.

    Args:
        type_str: Exception type name, optionally qualified with its module path.
        message: The exception message.
        traceback: The traceback, formatted as a string.
        original_repr: Text used for the fallback ``RemoteError``; when empty or
            None it defaults to ``"<type_str>('<message>')"``.

    Returns:
        The reconstructed exception, or a ``RemoteError`` whose ``__cause__`` is the
        reconstruction failure; either way passed through ``attach_remote_error_info``.
    """
    fallback_repr = original_repr or f"{type_str}('{message}')"

    ex: BaseException
    try:
        ex = reconstruct_exception(type_str=type_str, message=message)
    except Exception as reconstruction_error:
        # Could not rebuild the original type: use the generic remote error and
        # keep the reason available through the exception chain.
        ex = RemoteError(fallback_repr)
        ex.__cause__ = reconstruction_error

    return attach_remote_error_info(ex, type_str, message, traceback)
|
|
160
35
|
|
|
161
|
-
Returns:
|
|
162
|
-
An exception object with the original traceback information
|
|
163
36
|
|
|
164
|
-
|
|
165
|
-
"""
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
exc_type = getattr(module, class_name)
|
|
178
|
-
if exc_type is None or not issubclass(exc_type, Exception):
|
|
179
|
-
raise TypeError(f"{exc_type_name} is not a known exception type")
|
|
180
|
-
# Create the exception instance
|
|
181
|
-
exception = exc_type(exc_value)
|
|
182
|
-
except (ImportError, AttributeError, TypeError):
|
|
183
|
-
# Fall back to a generic exception
|
|
184
|
-
exception = RuntimeError(
|
|
185
|
-
f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}"
|
|
186
|
-
)
|
|
37
|
+
def reconstruct_exception(type_str: str, message: str) -> BaseException:
    """
    Best effort reconstruction of an exception from metadata.

    Args:
        type_str: Exception type name. A bare name (no dot) is looked up in
            ``builtins``; a dotted name is split at the last dot into module and
            attribute, and the module is imported.
        message: Passed as the single constructor argument of the resolved type.

    Returns:
        A new instance of the resolved exception type.

    Raises:
        ModuleNotFoundError: If the module cannot be imported or does not define
            the requested attribute.
        TypeError: If the resolved attribute is not an exception class (or the
            class cannot be constructed from a single message argument).
    """
    try:
        type_split = type_str.rsplit(".", 1)
        if len(type_split) == 1:
            module = builtins
        else:
            # NOTE(review): the module name comes from the serialized result, so
            # this imports (and runs the import side effects of) whatever module
            # the payload names. Confirm the payload source is trusted.
            module = importlib.import_module(type_split[0])
        exc_type = getattr(module, type_split[-1])
    except (ImportError, AttributeError):
        raise ModuleNotFoundError(
            f"Unrecognized exception type '{type_str}', likely due to a missing or unavailable package"
        ) from None

    # Check for a class first: issubclass() itself raises an unrelated TypeError
    # when handed a non-class such as a function or module attribute.
    if not (isinstance(exc_type, type) and issubclass(exc_type, BaseException)):
        raise TypeError(f"Imported type {type_str} is not a known exception type, possibly due to a name conflict")
    return cast(BaseException, exc_type(message))
|
|
190
54
|
|
|
191
55
|
|
|
192
|
-
def
|
|
56
|
+
def attach_remote_error_info(ex: BaseException, exc_type: str, exc_msg: str, traceback_str: str) -> BaseException:
|
|
193
57
|
"""
|
|
194
58
|
Attach a string-formatted traceback to an exception.
|
|
195
59
|
|
|
@@ -207,11 +71,11 @@ def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceb
|
|
|
207
71
|
"""
|
|
208
72
|
# Store the traceback information
|
|
209
73
|
exc_type = exc_type.rsplit(".", 1)[-1] # Remove module path
|
|
210
|
-
setattr(ex, _REMOTE_ERROR_ATTR_NAME,
|
|
74
|
+
setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteErrorInfo(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str))
|
|
211
75
|
return ex
|
|
212
76
|
|
|
213
77
|
|
|
214
|
-
def
|
|
78
|
+
def retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteErrorInfo]:
|
|
215
79
|
"""
|
|
216
80
|
Retrieve the string-formatted traceback from an exception if it exists.
|
|
217
81
|
|
|
@@ -285,7 +149,7 @@ def _install_sys_excepthook() -> None:
|
|
|
285
149
|
sys.excepthook is the global hook that Python calls when an unhandled exception occurs.
|
|
286
150
|
By default it prints the exception type, message and traceback to stderr.
|
|
287
151
|
|
|
288
|
-
We override sys.excepthook to intercept exceptions that contain our special
|
|
152
|
+
We override sys.excepthook to intercept exceptions that contain our special RemoteErrorInfo
|
|
289
153
|
attribute. These exceptions come from deserialized remote execution results and contain
|
|
290
154
|
the original traceback information from where they occurred.
|
|
291
155
|
|
|
@@ -327,7 +191,7 @@ def _install_sys_excepthook() -> None:
|
|
|
327
191
|
"\nDuring handling of the above exception, another exception occurred:\n", file=sys.stderr
|
|
328
192
|
)
|
|
329
193
|
|
|
330
|
-
if (remote_err :=
|
|
194
|
+
if (remote_err := retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteErrorInfo):
|
|
331
195
|
# Display stored traceback for deserialized exceptions
|
|
332
196
|
print("Traceback (from remote execution):", file=sys.stderr) # noqa: T201
|
|
333
197
|
print(remote_err.exc_tb, end="", file=sys.stderr) # noqa: T201
|
|
@@ -408,7 +272,7 @@ def _install_ipython_hook() -> bool:
|
|
|
408
272
|
tb_offset: Optional[int],
|
|
409
273
|
**kwargs: Any,
|
|
410
274
|
) -> list[list[str]]:
|
|
411
|
-
if (remote_err :=
|
|
275
|
+
if (remote_err := retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteErrorInfo):
|
|
412
276
|
# Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
|
|
413
277
|
head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
|
|
414
278
|
"(most recent call last)",
|
|
@@ -448,7 +312,7 @@ def _install_ipython_hook() -> bool:
|
|
|
448
312
|
tb_offset: Optional[int] = None,
|
|
449
313
|
**kwargs: Any,
|
|
450
314
|
) -> list[str]:
|
|
451
|
-
if (remote_err :=
|
|
315
|
+
if (remote_err := retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteErrorInfo):
|
|
452
316
|
tb_list = [
|
|
453
317
|
(m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
|
|
454
318
|
for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, remote_err.exc_tb or "")
|
|
@@ -493,9 +357,16 @@ def _uninstall_ipython_hook() -> None:
|
|
|
493
357
|
|
|
494
358
|
|
|
495
359
|
def install_exception_display_hooks() -> None:
|
|
496
|
-
|
|
497
|
-
_install_sys_excepthook()
|
|
360
|
+
"""Install custom exception display hooks for improved remote error reporting.
|
|
498
361
|
|
|
362
|
+
This function should be called once during package initialization to set up
|
|
363
|
+
enhanced error handling for remote job execution errors. The hooks will:
|
|
499
364
|
|
|
500
|
-
|
|
501
|
-
|
|
365
|
+
- Display original remote tracebacks instead of local deserialization traces
|
|
366
|
+
- Work in both standard Python and IPython/Jupyter environments
|
|
367
|
+
- Safely fall back to original behavior if errors occur
|
|
368
|
+
|
|
369
|
+
Note: This function is idempotent and safe to call multiple times.
|
|
370
|
+
"""
|
|
371
|
+
if not _install_ipython_hook():
|
|
372
|
+
_install_sys_excepthook()
|