snowflake-ml-python 1.25.0__py3-none-any.whl → 1.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +1 -26
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +7 -52
- snowflake/ml/jobs/_interop/protocols.py +7 -124
- snowflake/ml/jobs/_interop/utils.py +33 -92
- snowflake/ml/jobs/_utils/constants.py +0 -4
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/payload_utils.py +40 -6
- snowflake/ml/jobs/_utils/runtime_env_utils.py +111 -12
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +27 -204
- snowflake/ml/jobs/_utils/spec_utils.py +22 -0
- snowflake/ml/jobs/decorators.py +22 -17
- snowflake/ml/jobs/job.py +10 -25
- snowflake/ml/jobs/job_definition.py +4 -90
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/METADATA +7 -1
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/RECORD +20 -19
- snowflake/ml/jobs/_utils/arg_protocol.py +0 -7
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/mixins.py

@@ -9,7 +9,6 @@ _SESSION_ACCOUNT_KEY = "session$account"
 _SESSION_ROLE_KEY = "session$role"
 _SESSION_DATABASE_KEY = "session$database"
 _SESSION_SCHEMA_KEY = "session$schema"
-_SESSION_STATE_ATTR = "_session_state"
 
 
 def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
@@ -62,7 +61,7 @@ class SerializableSessionMixin:
         else:
             self.__dict__.update(state)
 
-
+        self._set_session(session_state)
 
     def _set_session(self, session_state: _SessionState) -> None:
 
@@ -87,27 +86,3 @@ class SerializableSessionMixin:
                 ),
             ),
         )
-
-    @property
-    def session(self) -> Optional[snowpark_session.Session]:
-        if _SESSION_KEY not in self.__dict__:
-            session_state = getattr(self, _SESSION_STATE_ATTR, None)
-            if session_state is not None:
-                self._set_session(session_state)
-        return self.__dict__.get(_SESSION_KEY)
-
-    @session.setter
-    def session(self, value: Optional[snowpark_session.Session]) -> None:
-        self.__dict__[_SESSION_KEY] = value
-
-    # _getattr__ is only called when an attribute is NOT found through normal lookup.
-    # 1. Data descriptors (like @property with setter) from the class hierarchy
-    # 2. Instance __dict__ (e.g., self.x = 10)
-    # 3. Non-data descriptors (methods, `@property without setter) from the class hierarchy
-    # __getattr__ — only called if steps 1-3 all fail
-    def __getattr__(self, name: str) -> Any:
-        if name == _SESSION_KEY:
-            return self.session
-        if hasattr(super(), "__getattr__"):
-            return super().__getattr__(name)  # type: ignore[misc]
-        raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
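Net effect of the mixins.py hunks: the session is reattached eagerly inside `__setstate__` instead of lazily through the removed `session` property and `__getattr__` hook. A minimal standalone sketch of the eager pattern, with a hypothetical `DummySession` and `_reconnect` standing in for `snowpark.Session` and the mixin's `_set_session` logic:

```python
import pickle
from typing import Any, Optional


class DummySession:
    """Hypothetical stand-in for snowpark.Session (real sessions are not picklable)."""


def _reconnect(session_state: Any) -> Optional[DummySession]:
    # Stand-in for _set_session's reconnection logic in mixins.py.
    return DummySession() if session_state is not None else None


class EagerRestore:
    def __init__(self) -> None:
        self.session: Optional[DummySession] = DummySession()
        self._session_state = {"account": "A", "role": "R"}

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("session", None)  # the live session never crosses the pickle boundary
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        # 1.25.1 behavior: restore immediately instead of waiting for attribute access
        self.session = _reconnect(state.get("_session_state"))


restored = pickle.loads(pickle.dumps(EagerRestore()))
assert isinstance(restored.session, DummySession)
```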
snowflake/ml/jobs/_interop/data_utils.py

@@ -31,7 +31,7 @@ class StageFileWriter(io.IOBase):
         # Only upload if buffer has content and no exception occurred
         if write_contents and self._buffer.tell() > 0:
             self._buffer.seek(0)
-            self._session.file.put_stream(self._buffer, self._path
+            self._session.file.put_stream(self._buffer, self._path)
         self._buffer.close()
         self._closed = True
 
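The fix restores the closing parenthesis on the upload call. For context, a hedged sketch of the same buffer-then-upload pattern using Snowpark's `Session.file.put_stream`; the stage path and helper name are illustrative:

```python
import io

from snowflake.snowpark import Session


def write_bytes_to_stage(session: Session, data: bytes, stage_path: str = "@my_stage/out/result.bin") -> None:
    buffer = io.BytesIO()
    buffer.write(data)
    if buffer.tell() > 0:      # only upload when something was written
        buffer.seek(0)         # put_stream reads from the current stream position
        session.file.put_stream(buffer, stage_path, auto_compress=False)
    buffer.close()
```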
@@ -84,15 +84,15 @@ class DtoCodec(Protocol):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
         pass
 
     @staticmethod
-    def encode(dto: dto_schema.
+    def encode(dto: dto_schema.ResultDTO) -> bytes:
         pass
 
 
@@ -104,18 +104,18 @@ class JsonDtoCodec(DtoCodec):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
         data = cast(dict[str, Any], json.load(stream))
         if as_dict:
             return data
-        return dto_schema.
+        return dto_schema.ResultDTO.model_validate(data)
 
     @staticmethod
-    def encode(dto: dto_schema.
+    def encode(dto: dto_schema.ResultDTO) -> bytes:
         # Temporarily extract the value to avoid accidentally applying model_dump() on it
         result_value = dto.value
         dto.value = None  # Clear value to avoid serializing it in the model_dump
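Both codecs now speak plain `ResultDTO`. A trimmed round-trip sketch of the encode/decode contract; the two-field model here is illustrative, not the full schema:

```python
import io
import json
from typing import Any, Optional, Union

from pydantic import BaseModel


class ResultDTO(BaseModel):
    """Trimmed stand-in for dto_schema.ResultDTO."""

    success: bool
    value: Optional[Any] = None


def encode(dto: ResultDTO) -> bytes:
    return dto.model_dump_json().encode("utf-8")


def decode(stream: io.IOBase, as_dict: bool = False) -> Union[ResultDTO, dict[str, Any]]:
    data = json.load(stream)
    return data if as_dict else ResultDTO.model_validate(data)


payload = encode(ResultDTO(success=True, value=42))
assert decode(io.BytesIO(payload)).value == 42
```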
snowflake/ml/jobs/_interop/dto_schema.py

@@ -1,7 +1,7 @@
-from typing import Any,
+from typing import Any, Optional, Union
 
-from pydantic import BaseModel,
-from typing_extensions import
+from pydantic import BaseModel, model_validator
+from typing_extensions import NotRequired, TypedDict
 
 
 class BinaryManifest(TypedDict):
@@ -67,47 +67,22 @@ class ExceptionMetadata(ResultMetadata):
     traceback: str
 
 
-class PayloadDTO(BaseModel):
-    """
-    Base class for serializable payloads.
-
-    Args:
-        kind: Discriminator field for DTO type dispatch.
-        value: The payload value (if JSON-serializable).
-        protocol: The protocol used to serialize the payload (if not JSON-serializable).
-    """
-
-    kind: Literal["base"] = "base"
-    value: Optional[Any] = None
-    protocol: Optional[ProtocolInfo] = None
-    serialize_error: Optional[str] = None
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_fields(cls, data: Any) -> Any:
-        """Ensure at least one of value or protocol keys is specified."""
-        if cls is PayloadDTO and isinstance(data, dict):
-            required_fields = {"value", "protocol"}
-            if not any(field in data for field in required_fields):
-                raise ValueError("At least one of 'value' or 'protocol' must be specified")
-        return data
-
-
-class ResultDTO(PayloadDTO):
+class ResultDTO(BaseModel):
     """
     A JSON representation of an execution result.
 
     Args:
-        kind: Discriminator field for DTO type dispatch.
         success: Whether the execution was successful.
         value: The value of the execution or the exception if the execution failed.
         protocol: The protocol used to serialize the result.
         metadata: The metadata of the result.
     """
 
-    kind: Literal["result"] = "result"  # type: ignore[assignment]
     success: bool
+    value: Optional[Any] = None
+    protocol: Optional[ProtocolInfo] = None
     metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
+    serialize_error: Optional[str] = None
 
     @model_validator(mode="before")
     @classmethod
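The flattened model keeps the `model_validator(mode="before")` guard shown in the next hunk. A standalone sketch of that validation pattern; the field set mirrors the diff, with `ProtocolInfo` and the metadata types simplified to plain dicts:

```python
from typing import Any, Optional

from pydantic import BaseModel, ValidationError, model_validator


class ResultDTO(BaseModel):
    """Simplified mirror of the 1.25.1 schema."""

    success: bool
    value: Optional[Any] = None
    protocol: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    serialize_error: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_fields(cls, data: Any) -> Any:
        if isinstance(data, dict):
            if not any(field in data for field in ("value", "protocol", "metadata")):
                raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
        return data


ResultDTO.model_validate({"success": True, "value": 1})  # accepted
try:
    ResultDTO.model_validate({"success": False})          # rejected by the validator
except ValidationError as e:
    print(e.errors()[0]["msg"])
```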
@@ -118,23 +93,3 @@ class ResultDTO(PayloadDTO):
             if not any(field in data for field in required_fields):
                 raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
         return data
-
-
-def _get_dto_kind(data: Any) -> str:
-    """Extract the 'kind' discriminator from input, defaulting to 'result' for backward compatibility."""
-    if isinstance(data, dict):
-        kind = data.get("kind", "result")
-    else:
-        kind = getattr(data, "kind", "result")
-    return str(kind)
-
-
-AnyResultDTO = Annotated[
-    Union[
-        Annotated[ResultDTO, Tag("result")],
-        Annotated[PayloadDTO, Tag("base")],
-    ],
-    Discriminator(_get_dto_kind),
-]
-
-ResultDTOAdapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
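For reference, the deleted block used pydantic's callable-discriminator feature to dispatch between the two DTO classes; with the `PayloadDTO` variant gone, the adapter became unnecessary. A compact, runnable sketch of that dispatch pattern with stand-in models:

```python
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class PayloadDTO(BaseModel):  # stand-ins for the removed classes
    kind: Literal["base"] = "base"


class ResultDTO(PayloadDTO):
    kind: Literal["result"] = "result"  # type: ignore[assignment]
    success: bool


def _get_dto_kind(data: Any) -> str:
    # Default to "result" for payloads written before "kind" existed.
    if isinstance(data, dict):
        return str(data.get("kind", "result"))
    return str(getattr(data, "kind", "result"))


AnyResultDTO = Annotated[
    Union[Annotated[ResultDTO, Tag("result")], Annotated[PayloadDTO, Tag("base")]],
    Discriminator(_get_dto_kind),
]

adapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
assert isinstance(adapter.validate_python({"success": True}), ResultDTO)
assert type(adapter.validate_python({"kind": "base"})) is PayloadDTO
```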
snowflake/ml/jobs/_interop/protocols.py

@@ -17,8 +17,6 @@ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
 
 logger = logging.getLogger(__name__)
 
-SESSION_KEY_PREFIX = "session@"
-
 
 class SerializationError(TypeError):
     """Exception raised when a serialization protocol fails."""
@@ -138,10 +136,9 @@ class CloudPickleProtocol(SerializationProtocol):
 
     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
         """Save the object to the destination directory."""
-        replaced_obj = self._pack_obj(obj)
         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
         with data_utils.open_stream(result_path, "wb", session=session) as f:
-            self._backend.dump(
+            self._backend.dump(obj, f)
         manifest: BinaryManifest = {"path": result_path}
         return self.protocol_info.with_manifest(manifest)
 
@@ -160,15 +157,12 @@ class CloudPickleProtocol(SerializationProtocol):
         payload_manifest = cast(BinaryManifest, payload_info.manifest)
         try:
             if payload_bytes := payload_manifest.get("bytes"):
-
-
-
-            else
-
-
-                result = self._backend.load(f)
-
-            return self._unpack_obj(result, session=session)
+                return self._backend.loads(payload_bytes)
+            if payload_b64 := payload_manifest.get("base64"):
+                return self._backend.loads(base64.b64decode(payload_b64))
+            result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+            with data_utils.open_stream(result_path, "rb", session=session) as f:
+                return self._backend.load(f)
         except (
             pickle.UnpicklingError,
             TypeError,
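The rewritten `load` resolves the payload straight from the manifest: inline bytes first, then base64 text, then a staged file path. A self-contained sketch of that branching, with plain `pickle` and a local file standing in for the cloudpickle backend and the stage stream:

```python
import base64
import pickle
from typing import Any, Callable, Optional


def load_from_manifest(manifest: dict[str, Any], path_transform: Optional[Callable[[str], str]] = None) -> Any:
    if payload_bytes := manifest.get("bytes"):
        return pickle.loads(payload_bytes)                  # inline raw bytes
    if payload_b64 := manifest.get("base64"):
        return pickle.loads(base64.b64decode(payload_b64))  # inline base64 text
    result_path = path_transform(manifest["path"]) if path_transform else manifest["path"]
    with open(result_path, "rb") as f:                      # a stage stream in the real code
        return pickle.load(f)


inline = {"base64": base64.b64encode(pickle.dumps([1, 2, 3])).decode()}
assert load_from_manifest(inline) == [1, 2, 3]
```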
@@ -179,117 +173,6 @@ class CloudPickleProtocol(SerializationProtocol):
                 raise error from pickle_error
             raise
 
-    def _pack_obj(self, obj: Any) -> Any:
-        """Pack objects into JSON-safe dicts using reserved marker keys.
-
-        Markers:
-            - "type@": container type for list/tuple (list or tuple)
-            - "#<i>": positional element for list/tuple at index i
-            - "session@": placeholder for snowpark.Session values
-            - "session@#<i>" for list/tuple entries
-            - "session@<key>" for dict entries
-            - {"session@": None} for a bare Session object
-
-        Example:
-            obj = {"x": [1, session], "s": session}
-            packed = {
-                "x": {"type@": list, "#0": 1, "session@#1": None},
-                "session@s": None,
-            }
-            _unpack_obj(packed, session) == obj
-
-        Args:
-            obj: Object to pack into JSON-safe marker dictionaries.
-
-        Returns:
-            Packed representation with markers for session references.
-        """
-        arguments: dict[str, Any] = {}
-        if isinstance(obj, tuple) or isinstance(obj, list):
-            arguments = {"type@": type(obj)}
-            for i, arg in enumerate(obj):
-                if isinstance(arg, snowpark.Session):
-                    arguments[f"{SESSION_KEY_PREFIX}#{i}"] = None
-                else:
-                    arguments[f"#{i}"] = self._pack_obj(arg)
-            return arguments
-        elif isinstance(obj, dict):
-            for k, v in obj.items():
-                if isinstance(v, snowpark.Session):
-                    arguments[f"{SESSION_KEY_PREFIX}{k}"] = None
-                else:
-                    arguments[k] = self._pack_obj(v)
-            return arguments
-        elif isinstance(obj, snowpark.Session):
-            # Box session into a dict marker so we can distinguish it from other plain objects.
-            arguments[f"{SESSION_KEY_PREFIX}"] = None
-            return arguments
-        else:
-            return obj
-
-    def _unpack_obj(self, obj: Any, session: Optional[snowpark.Session] = None) -> Any:
-        """Unpack dict markers back into containers and Session references.
-
-        Markers:
-            - "type@": container type for list/tuple (list or tuple)
-            - "#<i>": positional element for list/tuple at index i
-            - "session@": placeholder for snowpark.Session values
-            - "session@#<i>" for list/tuple entries
-            - "session@<key>" for dict entries
-            - {"session@": None} for a bare Session object
-
-        Example:
-            packed = {
-                "x": {"type@": list, "#0": 1, "session@#1": None},
-                "session@s": None,
-            }
-            obj = _unpack_obj(packed, session)
-            # obj == {"x": [1, session], "s": session}
-
-        Args:
-            obj: Packed object with marker dictionaries.
-            session: Session to inject for session markers.
-
-        Returns:
-            Unpacked object with session references restored.
-        """
-        if not isinstance(obj, dict):
-            return obj
-        elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
-            return session
-        else:
-            type = obj.get("type@", None)
-            # If type is None, we are unpacking a dict
-            if type is None:
-                result_dict = {}
-                for k, v in obj.items():
-                    if k.startswith(SESSION_KEY_PREFIX):
-                        result_key = k[len(SESSION_KEY_PREFIX) :]
-                        result_dict[result_key] = session
-                    else:
-                        result_dict[k] = self._unpack_obj(v, session)
-                return result_dict
-            # If type is not None, we are unpacking a tuple or list
-            else:
-                indexes = []
-                for k, _ in obj.items():
-                    if "#" in k:
-                        indexes.append(int(k.split("#")[-1]))
-
-                if not indexes:
-                    return tuple() if type is tuple else []
-                result_list: list[Any] = [None] * (max(indexes) + 1)
-
-                for k, v in obj.items():
-                    if k == "type@":
-                        continue
-                    idx = int(k.split("#")[-1])
-                    if k.startswith(SESSION_KEY_PREFIX):
-                        result_list[idx] = session
-                    else:
-                        result_list[idx] = self._unpack_obj(v, session)
-                return tuple(result_list) if type is tuple else result_list
-
 
 class ArrowTableProtocol(SerializationProtocol):
     """
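The deleted helpers implemented the session-marker scheme documented in their docstrings; with call arguments now carried by `FunctionPayload` (new file below), the markers are gone. The docstring's own example, made runnable in a condensed re-statement of the packing direction, with a hypothetical `DummySession` in place of `snowpark.Session`:

```python
from typing import Any

SESSION_KEY_PREFIX = "session@"


class DummySession:
    """Hypothetical stand-in for snowpark.Session."""


def pack(obj: Any) -> Any:
    """Condensed restatement of the removed _pack_obj logic."""
    if isinstance(obj, (list, tuple)):
        packed: dict[str, Any] = {"type@": type(obj)}
        for i, item in enumerate(obj):
            if isinstance(item, DummySession):
                packed[f"{SESSION_KEY_PREFIX}#{i}"] = None
            else:
                packed[f"#{i}"] = pack(item)
        return packed
    if isinstance(obj, dict):
        packed = {}
        for k, v in obj.items():
            if isinstance(v, DummySession):
                packed[f"{SESSION_KEY_PREFIX}{k}"] = None
            else:
                packed[k] = pack(v)
        return packed
    if isinstance(obj, DummySession):
        return {SESSION_KEY_PREFIX: None}
    return obj


session = DummySession()
packed = pack({"x": [1, session], "s": session})
assert packed == {"x": {"type@": list, "#0": 1, "session@#1": None}, "session@s": None}
```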
snowflake/ml/jobs/_interop/utils.py

@@ -1,4 +1,3 @@
-import io
 import logging
 import os
 import traceback
@@ -11,9 +10,7 @@ from snowflake import snowpark
 from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
 from snowflake.ml.jobs._interop.dto_schema import (
     ExceptionMetadata,
-    PayloadDTO,
     ResultDTO,
-    ResultDTOAdapter,
     ResultMetadata,
 )
 from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
@@ -26,137 +23,79 @@ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
 
-# Constants for argument encoding
-_MAX_INLINE_SIZE = 1024 * 1024  # 1MB - https://docs.snowflake.com/en/user-guide/query-size-limits
 
 logger = logging.getLogger(__name__)
 
 
-def
-    value: Any,
-    path: str,
-    session: Optional[snowpark.Session] = None,
-    max_inline_size: int = 0,
-) -> Optional[bytes]:
+def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
     """
-
-
-    Args:
-        value: The value to serialize. If ExecutionResult, creates ResultDTO with success flag.
-        path: Full file path for writing the DTO (if needed). Protocol data saved to path's parent.
-        session: Snowpark session for stage operations.
-        max_inline_size: Max bytes for inline return. 0 = always write to file.
-
-    Returns:
-        Encoded bytes if <= max_inline_size, else None (written to file).
-
-    Raises:
-        Exception: If session validation fails during serialization.
+    Save the result to a file.
     """
-
-
-
-
-    dto = PayloadDTO(value=value)
-    raw_value = value
+    result_dto = ResultDTO(
+        success=result.success,
+        value=result.value,
+    )
 
     try:
-
+        # Try to encode result directly
+        payload = DEFAULT_CODEC.encode(result_dto)
     except TypeError:
-
-
-        # Metadata enables client fallback display when result can't be deserialized (protocol mismatch)..
-        dto.metadata = _get_metadata(raw_value)
+        result_dto.value = None  # Remove raw value to avoid serialization error
+        result_dto.metadata = _get_metadata(result.value)  # Add metadata for client fallback on protocol mismatch
        try:
             path_dir = PurePath(path).parent.as_posix()
-            protocol_info = DEFAULT_PROTOCOL.save(
-
+            protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
+            result_dto.protocol = protocol_info
 
         except Exception as e:
             logger.warning(f"Error dumping result value: {repr(e)}")
-
-
-
-
-            # 2. Input Arguments: Trigger a "hard-fail."
-            # If arguments cannot be saved, the job script cannot run. We raise
-            # an immediate exception to prevent execution with invalid state.
-            if not isinstance(dto, ResultDTO):
-                raise
-            dto.serialize_error = repr(e)
-
-        # Encode the modified DTO
-        payload = DEFAULT_CODEC.encode(dto)
-
-    if not isinstance(dto, ResultDTO) and len(payload) <= max_inline_size:
-        return payload
+            result_dto.serialize_error = repr(e)
+
+        # Encode the modified result DTO
+        payload = DEFAULT_CODEC.encode(result_dto)
 
     with data_utils.open_stream(path, "wb", session=session) as stream:
         stream.write(payload)
-    return None
-
-
-save_result = save  # Backwards compatibility
 
 
-def
-
-
-
-) -> Any:
-    """Load data from a file path or inline string."""
-
+def load_result(
+    path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
+) -> ExecutionResult:
+    """Load the result from a file on a Snowflake stage."""
     try:
-        with data_utils.open_stream(
+        with data_utils.open_stream(path, "r", session=session) as stream:
             # Load the DTO as a dict for easy fallback to legacy loading if necessary
             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
-    # the exception could be OSError or BlockingIOError(the file name is too long)
-    except OSError as e:
-        # path_or_data might be inline data
-        try:
-            dto_dict = DEFAULT_CODEC.decode(io.StringIO(path_or_data), as_dict=True)
-        except Exception:
-            raise e
     except UnicodeDecodeError:
         # Path may be a legacy result file (cloudpickle)
+        # TODO: Re-use the stream
         assert session is not None
-        return legacy.load_legacy_result(session,
+        return legacy.load_legacy_result(session, path)
 
     try:
-        dto =
+        dto = ResultDTO.model_validate(dto_dict)
     except pydantic.ValidationError as e:
         if "success" in dto_dict:
             assert session is not None
-            if
-
-            return legacy.load_legacy_result(session,
+            if path.endswith(".json"):
+                path = os.path.splitext(path)[0] + ".pkl"
+            return legacy.load_legacy_result(session, path, result_json=dto_dict)
         raise ValueError("Invalid result schema") from e
 
     # Try loading data from file using the protocol info
-
+    result_value = None
     data_load_error = None
     if dto.protocol is not None:
         try:
             logger.debug(f"Loading result value with protocol {dto.protocol}")
-
+            result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
         except sp_exceptions.SnowparkSQLException:
             raise  # Data retrieval errors should be bubbled up
         except Exception as e:
             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
-            # Error handling strategy depends on the DTO type:
-            # 1. ResultDTO: Soft-fail. The job has already finished.
-            # We package the load error into the result so the client can
-            # debug or adjust their protocol version to retrieve the output.
-            # 2. PayloadDTO : Raise a hard error. If arguments cannot be
-            # loaded, the job cannot run. We abort early to prevent execution.
-            if not isinstance(dto, ResultDTO):
-                raise
             data_load_error = e
 
-    #
-    if not isinstance(dto, ResultDTO):
-        return payload_value
-
+    # Wrap serialize_error in a TypeError
     if dto.serialize_error:
         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
         if data_load_error:
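End to end, the interop surface shrinks to a symmetric save/load pair. A hedged usage sketch: the stage path is illustrative, `Session.builder.create()` assumes configured connection parameters, and the `ExecutionResult(success=..., value=...)` constructor call is an assumption inferred from the attributes `save_result` reads above:

```python
from snowflake.ml.jobs._interop import utils as interop_utils
from snowflake.ml.jobs._interop.results import ExecutionResult
from snowflake.snowpark import Session

session = Session.builder.create()  # assumes connection parameters are configured
result_path = "@my_job_stage/output/mljob_result.json"  # hypothetical stage location

# Assumption: ExecutionResult accepts success/value as constructor arguments,
# matching how save_result reads result.success and result.value.
result = ExecutionResult(success=True, value={"rows": 42})
interop_utils.save_result(result, result_path, session=session)

loaded = interop_utils.load_result(result_path, session=session)
assert loaded.success and loaded.value == {"rows": 42}
```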
@@ -164,7 +103,8 @@ def load(
         else:
             data_load_error = serialize_error
 
-
+    # Prepare to assemble the final result
+    result_value = result_value if result_value is not None else dto.value
     if not dto.success and result_value is None:
         # Try to reconstruct exception from metadata if available
         if isinstance(dto.metadata, ExceptionMetadata):
@@ -175,6 +115,7 @@ def load(
                 traceback=dto.metadata.traceback,
                 original_repr=dto.metadata.repr,
             )
+
     # Generate a generic error if we still don't have a value,
     # attaching the data load error if any
     if result_value is None:
snowflake/ml/jobs/_utils/constants.py

@@ -5,7 +5,6 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-RESULT_VOLUME_NAME = "result-volume"
 DEFAULT_PYTHON_VERSION = "3.10"
 
 # Environment variables
@@ -110,6 +109,3 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
     SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
-
-# Magic attributes
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
snowflake/ml/jobs/_utils/function_payload_utils.py (new file)

@@ -0,0 +1,43 @@
+import inspect
+from typing import Any, Callable, Optional
+
+from snowflake import snowpark
+from snowflake.snowpark import context as sp_context
+
+
+class FunctionPayload:
+    def __init__(
+        self,
+        func: Callable[..., Any],
+        session: Optional[snowpark.Session] = None,
+        session_argument: str = "",
+        *args: Any,
+        **kwargs: Any
+    ) -> None:
+        self.function = func
+        self.args = args
+        self.kwargs = kwargs
+        self._session = session
+        self._session_argument = session_argument
+
+    @property
+    def session(self) -> Optional[snowpark.Session]:
+        return self._session
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude session."""
+        state = self.__dict__.copy()
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        self.__dict__.update(state)
+        self._session = sp_context.get_active_session()
+
+    def __call__(self) -> Any:
+        sig = inspect.signature(self.function)
+        bound = sig.bind_partial(*self.args, **self.kwargs)
+        bound.arguments[self._session_argument] = self._session
+
+        return self.function(*bound.args, **bound.kwargs)