snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/_client/model/model_version_impl.py +25 -14
- snowflake/ml/model/_client/ops/service_ops.py +6 -6
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/mixins.py

@@ -9,6 +9,7 @@ _SESSION_ACCOUNT_KEY = "session$account"
 _SESSION_ROLE_KEY = "session$role"
 _SESSION_DATABASE_KEY = "session$database"
 _SESSION_SCHEMA_KEY = "session$schema"
+_SESSION_STATE_ATTR = "_session_state"


 def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
@@ -61,7 +62,7 @@ class SerializableSessionMixin:
         else:
             self.__dict__.update(state)

-        self
+        setattr(self, _SESSION_STATE_ATTR, session_state)

     def _set_session(self, session_state: _SessionState) -> None:

@@ -86,3 +87,27 @@ class SerializableSessionMixin:
             ),
         ),
     )
+
+    @property
+    def session(self) -> Optional[snowpark_session.Session]:
+        if _SESSION_KEY not in self.__dict__:
+            session_state = getattr(self, _SESSION_STATE_ATTR, None)
+            if session_state is not None:
+                self._set_session(session_state)
+        return self.__dict__.get(_SESSION_KEY)
+
+    @session.setter
+    def session(self, value: Optional[snowpark_session.Session]) -> None:
+        self.__dict__[_SESSION_KEY] = value
+
+    # __getattr__ is only called when an attribute is NOT found through normal lookup, i.e. through:
+    # 1. Data descriptors (like @property with setter) from the class hierarchy
+    # 2. Instance __dict__ (e.g., self.x = 10)
+    # 3. Non-data descriptors (methods, @property without setter) from the class hierarchy
+    # __getattr__ is only called if steps 1-3 all fail
+    def __getattr__(self, name: str) -> Any:
+        if name == _SESSION_KEY:
+            return self.session
+        if hasattr(super(), "__getattr__"):
+            return super().__getattr__(name)  # type: ignore[misc]
+        raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
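Taken together with the `__setstate__` change above, the intent appears to be lazy session re-attachment after unpickling: the pickled state carries only the saved session identifiers, and the new `session` property re-resolves a live session on first access. A rough, hedged sketch of the consumer-side effect; the `Holder` class and the active-session setup are illustrative assumptions, not package code:

```python
# Illustrative sketch only; Holder and the surrounding setup are assumptions.
import pickle

from snowflake.snowpark.context import get_active_session
from snowflake.ml._internal.utils.mixins import SerializableSessionMixin


class Holder(SerializableSessionMixin):
    def __init__(self, session):
        self.session = session  # routed through the new property setter


holder = Holder(get_active_session())          # assumes code runs with an active Snowpark session
restored = pickle.loads(pickle.dumps(holder))  # pickling captures session identifiers, not the connection
_ = restored.session                           # property lazily re-attaches a matching live session
```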
snowflake/ml/data/_internal/arrow_ingestor.py

@@ -73,15 +73,19 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
         self._schema: Optional[pa.Schema] = None

     @classmethod
-    def from_sources(
+    def from_sources(
+        cls, session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
+    ) -> "ArrowIngestor":
         if session is None:
             raise ValueError("Session is required")
+        # Skipping kwargs until needed to avoid impacting other workflows.
         return cls(session, sources)

     @classmethod
     def from_ray_dataset(
         cls,
         ray_ds: "ray.data.Dataset",
+        **kwargs: Any,
     ) -> "ArrowIngestor":
         raise NotImplementedError

snowflake/ml/data/data_connector.py

@@ -94,7 +94,7 @@ class DataConnector:
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
-        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds, **kwargs)
         return cls(ray_ingestor, **kwargs)

     @classmethod
@@ -111,7 +111,7 @@ class DataConnector:
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
-        ingestor = ingestor_class.from_sources(session, sources)
+        ingestor = ingestor_class.from_sources(session, sources, **kwargs)
         return cls(ingestor, **kwargs)

     @property
snowflake/ml/data/data_ingestor.py

@@ -16,7 +16,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
+        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
     ) -> DataIngestorType:
         raise NotImplementedError

@@ -24,6 +24,7 @@ class DataIngestor(Protocol):
     def from_ray_dataset(
         cls: type[DataIngestorType],
         ray_ds: "ray.data.Dataset",
+        **kwargs: Any,
     ) -> DataIngestorType:
         raise NotImplementedError

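Both ingestor factory methods now accept arbitrary keyword arguments, and `DataConnector.from_sources` / `from_ray_dataset` forward their `**kwargs` through, as the hunks above show. A minimal sketch of what a conforming custom ingestor looks like under the widened protocol; `MyIngestor` is illustrative and not part of the package, and the `data_source` import path is assumed from the signatures above:

```python
# Illustrative custom ingestor satisfying the widened DataIngestor protocol.
# MyIngestor is an assumption for demonstration purposes only.
from typing import Any, Sequence

from snowflake import snowpark
from snowflake.ml.data import data_source


class MyIngestor:
    def __init__(self, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> None:
        self._session = session
        self._sources = sources

    @classmethod
    def from_sources(
        cls, session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
    ) -> "MyIngestor":
        # Extra kwargs forwarded by DataConnector can be consumed or ignored here.
        return cls(session, sources)

    @classmethod
    def from_ray_dataset(cls, ray_ds: Any, **kwargs: Any) -> "MyIngestor":
        raise NotImplementedError
```

With this shape, `DataConnector.from_sources(session, sources, ingestor_class=MyIngestor)` keeps working while any new connector-level keyword options flow through to the ingestor untouched.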
snowflake/ml/experiment/_experiment_info.py

@@ -3,7 +3,7 @@ import functools
 import types
 from typing import Callable, Optional

-from snowflake.ml import
+from snowflake.ml.model._client.model import model_version_impl
 from snowflake.ml.registry._manager import model_manager


@@ -23,7 +23,7 @@ class ExperimentInfoPatcher:
     """

     # Store original method at class definition time to avoid recursive patching
-    _original_log_model: Callable[...,
+    _original_log_model: Callable[..., model_version_impl.ModelVersion] = model_manager.ModelManager.log_model

     # Stack of active experiment_info contexts for nested experiment support
     _experiment_info_stack: list[ExperimentInfo] = []
@@ -36,7 +36,7 @@ class ExperimentInfoPatcher:
         if not ExperimentInfoPatcher._experiment_info_stack:

             @functools.wraps(ExperimentInfoPatcher._original_log_model)
-            def patched(*args, **kwargs) ->
+            def patched(*args, **kwargs) -> model_version_impl.ModelVersion:  # type: ignore[no-untyped-def]
                 # Use the most recent (top of stack) experiment_info for nested contexts
                 current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
                 return ExperimentInfoPatcher._original_log_model(
snowflake/ml/jobs/_interop/data_utils.py

@@ -31,7 +31,7 @@ class StageFileWriter(io.IOBase):
         # Only upload if buffer has content and no exception occurred
         if write_contents and self._buffer.tell() > 0:
             self._buffer.seek(0)
-            self._session.file.put_stream(self._buffer, self._path)
+            self._session.file.put_stream(self._buffer, self._path, auto_compress=False)
         self._buffer.close()
         self._closed = True

@@ -84,15 +84,15 @@ class DtoCodec(Protocol):

     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
         ...

     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
         pass

     @staticmethod
-    def encode(dto: dto_schema.
+    def encode(dto: dto_schema.PayloadDTO) -> bytes:
         pass


@@ -104,18 +104,18 @@ class JsonDtoCodec(DtoCodec):

     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
         ...

     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
         data = cast(dict[str, Any], json.load(stream))
         if as_dict:
             return data
-        return dto_schema.
+        return dto_schema.ResultDTOAdapter.validate_python(data)

     @staticmethod
-    def encode(dto: dto_schema.
+    def encode(dto: dto_schema.PayloadDTO) -> bytes:
         # Temporarily extract the value to avoid accidentally applying model_dump() on it
         result_value = dto.value
         dto.value = None  # Clear value to avoid serializing it in the model_dump
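The `auto_compress=False` change in `StageFileWriter` above keeps the uploaded file's name and bytes exactly as written, instead of letting the stage PUT gzip it to `<name>.gz`. A hedged sketch of the underlying Snowpark call; the stage path and file contents are illustrative assumptions:

```python
# Hedged sketch of the underlying Snowpark upload; stage path is illustrative.
import io

from snowflake.snowpark import Session


def upload_result(session: Session) -> None:
    buf = io.BytesIO(b'{"kind": "result", "success": true, "value": 42}')
    # auto_compress=False keeps the exact file name (no ".gz" suffix) and contents.
    session.file.put_stream(buf, "@my_stage/jobs/result.json", auto_compress=False)
```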
snowflake/ml/jobs/_interop/dto_schema.py

@@ -1,7 +1,7 @@
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union

-from pydantic import BaseModel, model_validator
-from typing_extensions import NotRequired, TypedDict
+from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, model_validator
+from typing_extensions import Annotated, NotRequired, TypedDict


 class BinaryManifest(TypedDict):
@@ -67,22 +67,47 @@ class ExceptionMetadata(ResultMetadata):
     traceback: str


-class
+class PayloadDTO(BaseModel):
+    """
+    Base class for serializable payloads.
+
+    Args:
+        kind: Discriminator field for DTO type dispatch.
+        value: The payload value (if JSON-serializable).
+        protocol: The protocol used to serialize the payload (if not JSON-serializable).
+    """
+
+    kind: Literal["base"] = "base"
+    value: Optional[Any] = None
+    protocol: Optional[ProtocolInfo] = None
+    serialize_error: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_fields(cls, data: Any) -> Any:
+        """Ensure at least one of value or protocol keys is specified."""
+        if cls is PayloadDTO and isinstance(data, dict):
+            required_fields = {"value", "protocol"}
+            if not any(field in data for field in required_fields):
+                raise ValueError("At least one of 'value' or 'protocol' must be specified")
+        return data
+
+
+class ResultDTO(PayloadDTO):
     """
     A JSON representation of an execution result.

     Args:
+        kind: Discriminator field for DTO type dispatch.
         success: Whether the execution was successful.
         value: The value of the execution or the exception if the execution failed.
         protocol: The protocol used to serialize the result.
         metadata: The metadata of the result.
     """

+    kind: Literal["result"] = "result"  # type: ignore[assignment]
     success: bool
-    value: Optional[Any] = None
-    protocol: Optional[ProtocolInfo] = None
     metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
-    serialize_error: Optional[str] = None

     @model_validator(mode="before")
     @classmethod
@@ -93,3 +118,23 @@ class ResultDTO(BaseModel):
             if not any(field in data for field in required_fields):
                 raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
         return data
+
+
+def _get_dto_kind(data: Any) -> str:
+    """Extract the 'kind' discriminator from input, defaulting to 'result' for backward compatibility."""
+    if isinstance(data, dict):
+        kind = data.get("kind", "result")
+    else:
+        kind = getattr(data, "kind", "result")
+    return str(kind)
+
+
+AnyResultDTO = Annotated[
+    Union[
+        Annotated[ResultDTO, Tag("result")],
+        Annotated[PayloadDTO, Tag("base")],
+    ],
+    Discriminator(_get_dto_kind),
+]
+
+ResultDTOAdapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
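Concretely, `ResultDTOAdapter` dispatches on the new `kind` field and falls back to `ResultDTO` when the field is absent, which is how documents written by older clients keep validating. A hedged usage sketch based only on the schema shown above; the example data is made up:

```python
# Hedged usage sketch based on the schema in this diff; example data is illustrative.
from snowflake.ml.jobs._interop.dto_schema import PayloadDTO, ResultDTO, ResultDTOAdapter

# "kind" routes validation to the matching DTO class.
result = ResultDTOAdapter.validate_python({"kind": "result", "success": True, "value": 42})
assert isinstance(result, ResultDTO)

payload = ResultDTOAdapter.validate_python({"kind": "base", "value": {"x": 1}})
assert isinstance(payload, PayloadDTO) and not isinstance(payload, ResultDTO)

# Documents without "kind" (written before this release) default to ResultDTO.
legacy = ResultDTOAdapter.validate_python({"success": True, "value": 1})
assert isinstance(legacy, ResultDTO)
```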
snowflake/ml/jobs/_interop/protocols.py

@@ -17,6 +17,8 @@ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]

 logger = logging.getLogger(__name__)

+SESSION_KEY_PREFIX = "session@"
+

 class SerializationError(TypeError):
     """Exception raised when a serialization protocol fails."""
@@ -136,9 +138,10 @@ class CloudPickleProtocol(SerializationProtocol):

     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
         """Save the object to the destination directory."""
+        replaced_obj = self._pack_obj(obj)
         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
         with data_utils.open_stream(result_path, "wb", session=session) as f:
-            self._backend.dump(
+            self._backend.dump(replaced_obj, f)
         manifest: BinaryManifest = {"path": result_path}
         return self.protocol_info.with_manifest(manifest)

@@ -157,12 +160,15 @@ class CloudPickleProtocol(SerializationProtocol):
         payload_manifest = cast(BinaryManifest, payload_info.manifest)
         try:
             if payload_bytes := payload_manifest.get("bytes"):
-
-
-
-
-
-
+                result = self._backend.loads(payload_bytes)
+            elif payload_b64 := payload_manifest.get("base64"):
+                result = self._backend.loads(base64.b64decode(payload_b64))
+            else:
+                result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+                with data_utils.open_stream(result_path, "rb", session=session) as f:
+                    result = self._backend.load(f)
+
+            return self._unpack_obj(result, session=session)
         except (
             pickle.UnpicklingError,
             TypeError,
@@ -173,6 +179,117 @@ class CloudPickleProtocol(SerializationProtocol):
                 raise error from pickle_error
             raise

+    def _pack_obj(self, obj: Any) -> Any:
+        """Pack objects into JSON-safe dicts using reserved marker keys.
+
+        Markers:
+            - "type@": container type for list/tuple (list or tuple)
+            - "#<i>": positional element for list/tuple at index i
+            - "session@": placeholder for snowpark.Session values
+            - "session@#<i>" for list/tuple entries
+            - "session@<key>" for dict entries
+            - {"session@": None} for a bare Session object
+
+        Example:
+            obj = {"x": [1, session], "s": session}
+            packed = {
+                "x": {"type@": list, "#0": 1, "session@#1": None},
+                "session@s": None,
+            }
+            _unpack_obj(packed, session) == obj
+
+        Args:
+            obj: Object to pack into JSON-safe marker dictionaries.
+
+        Returns:
+            Packed representation with markers for session references.
+        """
+        arguments: dict[str, Any] = {}
+        if isinstance(obj, tuple) or isinstance(obj, list):
+            arguments = {"type@": type(obj)}
+            for i, arg in enumerate(obj):
+                if isinstance(arg, snowpark.Session):
+                    arguments[f"{SESSION_KEY_PREFIX}#{i}"] = None
+                else:
+                    arguments[f"#{i}"] = self._pack_obj(arg)
+            return arguments
+        elif isinstance(obj, dict):
+            for k, v in obj.items():
+                if isinstance(v, snowpark.Session):
+                    arguments[f"{SESSION_KEY_PREFIX}{k}"] = None
+                else:
+                    arguments[k] = self._pack_obj(v)
+            return arguments
+        elif isinstance(obj, snowpark.Session):
+            # Box session into a dict marker so we can distinguish it from other plain objects.
+            arguments[f"{SESSION_KEY_PREFIX}"] = None
+            return arguments
+        else:
+            return obj
+
+    def _unpack_obj(self, obj: Any, session: Optional[snowpark.Session] = None) -> Any:
+        """Unpack dict markers back into containers and Session references.
+
+        Markers:
+            - "type@": container type for list/tuple (list or tuple)
+            - "#<i>": positional element for list/tuple at index i
+            - "session@": placeholder for snowpark.Session values
+            - "session@#<i>" for list/tuple entries
+            - "session@<key>" for dict entries
+            - {"session@": None} for a bare Session object
+
+        Example:
+            packed = {
+                "x": {"type@": list, "#0": 1, "session@#1": None},
+                "session@s": None,
+            }
+            obj = _unpack_obj(packed, session)
+            # obj == {"x": [1, session], "s": session}
+
+        Args:
+            obj: Packed object with marker dictionaries.
+            session: Session to inject for session markers.
+
+        Returns:
+            Unpacked object with session references restored.
+        """
+        if not isinstance(obj, dict):
+            return obj
+        elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
+            return session
+        else:
+            type = obj.get("type@", None)
+            # If type is None, we are unpacking a dict
+            if type is None:
+                result_dict = {}
+                for k, v in obj.items():
+                    if k.startswith(SESSION_KEY_PREFIX):
+                        result_key = k[len(SESSION_KEY_PREFIX) :]
+                        result_dict[result_key] = session
+                    else:
+                        result_dict[k] = self._unpack_obj(v, session)
+                return result_dict
+            # If type is not None, we are unpacking a tuple or list
+            else:
+                indexes = []
+                for k, _ in obj.items():
+                    if "#" in k:
+                        indexes.append(int(k.split("#")[-1]))
+
+                if not indexes:
+                    return tuple() if type is tuple else []
+                result_list: list[Any] = [None] * (max(indexes) + 1)
+
+                for k, v in obj.items():
+                    if k == "type@":
+                        continue
+                    idx = int(k.split("#")[-1])
+                    if k.startswith(SESSION_KEY_PREFIX):
+                        result_list[idx] = session
+                    else:
+                        result_list[idx] = self._unpack_obj(v, session)
+                return tuple(result_list) if type is tuple else result_list
+

 class ArrowTableProtocol(SerializationProtocol):
     """
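The methods above let `CloudPickleProtocol` strip live `snowpark.Session` objects out of a payload before pickling and re-inject the current session on load. Below is a self-contained miniature of the marker scheme (dict case only), using a stand-in `FakeSession` class; the real logic lives on `CloudPickleProtocol._pack_obj` / `_unpack_obj` and also handles lists and tuples:

```python
# Self-contained miniature of the session-marker scheme (dict case only).
# FakeSession stands in for snowpark.Session, which is not safe to pickle.
from typing import Any, Optional

SESSION_KEY_PREFIX = "session@"


class FakeSession:
    pass


def pack(obj: dict[str, Any]) -> dict[str, Any]:
    # Replace Session values with a '"session@<key>": None' marker.
    return {
        (SESSION_KEY_PREFIX + k if isinstance(v, FakeSession) else k): (None if isinstance(v, FakeSession) else v)
        for k, v in obj.items()
    }


def unpack(obj: dict[str, Any], session: Optional[FakeSession]) -> dict[str, Any]:
    # Strip the marker prefix and substitute the live session back in.
    return {
        (k[len(SESSION_KEY_PREFIX):] if k.startswith(SESSION_KEY_PREFIX) else k): (
            session if k.startswith(SESSION_KEY_PREFIX) else v
        )
        for k, v in obj.items()
    }


live = FakeSession()
packed = pack({"x": 1, "s": live})   # {"x": 1, "session@s": None}
restored = unpack(packed, live)      # {"x": 1, "s": live}
assert restored == {"x": 1, "s": live}
```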
snowflake/ml/jobs/_interop/utils.py

@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import traceback
@@ -10,7 +11,9 @@ from snowflake import snowpark
 from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
 from snowflake.ml.jobs._interop.dto_schema import (
     ExceptionMetadata,
+    PayloadDTO,
     ResultDTO,
+    ResultDTOAdapter,
     ResultMetadata,
 )
 from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
@@ -23,79 +26,137 @@ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)

+# Constants for argument encoding
+_MAX_INLINE_SIZE = 1024 * 1024  # 1MB - https://docs.snowflake.com/en/user-guide/query-size-limits

 logger = logging.getLogger(__name__)


-def
+def save(
+    value: Any,
+    path: str,
+    session: Optional[snowpark.Session] = None,
+    max_inline_size: int = 0,
+) -> Optional[bytes]:
     """
-
+    Serialize a value. Returns inline bytes if small enough, else writes to file.
+
+    Args:
+        value: The value to serialize. If ExecutionResult, creates ResultDTO with success flag.
+        path: Full file path for writing the DTO (if needed). Protocol data saved to path's parent.
+        session: Snowpark session for stage operations.
+        max_inline_size: Max bytes for inline return. 0 = always write to file.
+
+    Returns:
+        Encoded bytes if <= max_inline_size, else None (written to file).
+
+    Raises:
+        Exception: If session validation fails during serialization.
     """
-
-    success=
-    value
-
+    if isinstance(value, ExecutionResult):
+        dto: PayloadDTO = ResultDTO(success=value.success, value=value.value)
+        raw_value = value.value
+    else:
+        dto = PayloadDTO(value=value)
+        raw_value = value

     try:
-
-        payload = DEFAULT_CODEC.encode(result_dto)
+        payload = DEFAULT_CODEC.encode(dto)
     except TypeError:
-
-
+        dto.value = None  # Remove raw value to avoid serialization error
+        if isinstance(dto, ResultDTO):
+            # Metadata enables client fallback display when result can't be deserialized (protocol mismatch).
+            dto.metadata = _get_metadata(raw_value)
         try:
             path_dir = PurePath(path).parent.as_posix()
-            protocol_info = DEFAULT_PROTOCOL.save(
-
+            protocol_info = DEFAULT_PROTOCOL.save(raw_value, path_dir, session=session)
+            dto.protocol = protocol_info

         except Exception as e:
             logger.warning(f"Error dumping result value: {repr(e)}")
-
-
-
-
+            # We handle serialization failures differently based on the DTO type:
+            # 1. Job Results (ResultDTO): Allow a "soft-fail."
+            #    Since the job has already executed, we return the serialization error
+            #    to the client so they can debug the output or update their protocol version.
+            # 2. Input Arguments: Trigger a "hard-fail."
+            #    If arguments cannot be saved, the job script cannot run. We raise
+            #    an immediate exception to prevent execution with invalid state.
+            if not isinstance(dto, ResultDTO):
+                raise
+            dto.serialize_error = repr(e)
+
+        # Encode the modified DTO
+        payload = DEFAULT_CODEC.encode(dto)
+
+    if not isinstance(dto, ResultDTO) and len(payload) <= max_inline_size:
+        return payload

     with data_utils.open_stream(path, "wb", session=session) as stream:
         stream.write(payload)
+    return None
+
+
+save_result = save  # Backwards compatibility


-def
-
-
-
+def load(
+    path_or_data: str,
+    session: Optional[snowpark.Session] = None,
+    path_transform: Optional[Callable[[str], str]] = None,
+) -> Any:
+    """Load data from a file path or inline string."""
+
     try:
-        with data_utils.open_stream(
+        with data_utils.open_stream(path_or_data, "r", session=session) as stream:
             # Load the DTO as a dict for easy fallback to legacy loading if necessary
             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+    # The exception could be OSError or BlockingIOError (e.g. the file name is too long)
+    except OSError as e:
+        # path_or_data might be inline data
+        try:
+            dto_dict = DEFAULT_CODEC.decode(io.StringIO(path_or_data), as_dict=True)
+        except Exception:
+            raise e
     except UnicodeDecodeError:
         # Path may be a legacy result file (cloudpickle)
-        # TODO: Re-use the stream
         assert session is not None
-        return legacy.load_legacy_result(session,
+        return legacy.load_legacy_result(session, path_or_data)

     try:
-        dto =
+        dto = ResultDTOAdapter.validate_python(dto_dict)
     except pydantic.ValidationError as e:
         if "success" in dto_dict:
             assert session is not None
-            if
-
-            return legacy.load_legacy_result(session,
+            if path_or_data.endswith(".json"):
+                path_or_data = os.path.splitext(path_or_data)[0] + ".pkl"
+            return legacy.load_legacy_result(session, path_or_data, result_json=dto_dict)
         raise ValueError("Invalid result schema") from e

     # Try loading data from file using the protocol info
-
+    payload_value = None
     data_load_error = None
     if dto.protocol is not None:
         try:
             logger.debug(f"Loading result value with protocol {dto.protocol}")
-
+            payload_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
         except sp_exceptions.SnowparkSQLException:
             raise  # Data retrieval errors should be bubbled up
         except Exception as e:
             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+            # Error handling strategy depends on the DTO type:
+            # 1. ResultDTO: Soft-fail. The job has already finished.
+            #    We package the load error into the result so the client can
+            #    debug or adjust their protocol version to retrieve the output.
+            # 2. PayloadDTO: Raise a hard error. If arguments cannot be
+            #    loaded, the job cannot run. We abort early to prevent execution.
+            if not isinstance(dto, ResultDTO):
+                raise
             data_load_error = e

-    #
+    # Prepare to assemble the final result
+    if not isinstance(dto, ResultDTO):
+        return payload_value
+
     if dto.serialize_error:
         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
         if data_load_error:
@@ -103,8 +164,7 @@ def load_result(
         else:
             data_load_error = serialize_error

-
-    result_value = result_value if result_value is not None else dto.value
+    result_value = payload_value if payload_value is not None else dto.value
     if not dto.success and result_value is None:
         # Try to reconstruct exception from metadata if available
         if isinstance(dto.metadata, ExceptionMetadata):
@@ -115,7 +175,6 @@ def load_result(
                 traceback=dto.metadata.traceback,
                 original_repr=dto.metadata.repr,
             )
-
     # Generate a generic error if we still don't have a value,
     # attaching the data load error if any
     if result_value is None:
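For orientation, the JSON document that `save` writes (and `load` decodes via `ResultDTOAdapter`) now carries the `kind` discriminator alongside the fields shown in the dto_schema hunks above. The sketch below is illustrative only; the exact contents of the `protocol` entry depend on which serialization protocol handled the value and are only partially shown:

```python
# Illustrative shape of a serialized result document; all values are made up.
example_result_document = {
    "kind": "result",        # new discriminator; older payloads without it default to "result"
    "success": True,
    "value": None,           # cleared when the raw value is not JSON-serializable
    "protocol": {            # ProtocolInfo written by e.g. CloudPickleProtocol; fields abbreviated
        "manifest": {"path": "..."},
    },
    "metadata": None,
    "serialize_error": None,
}
```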
snowflake/ml/jobs/_utils/constants.py

@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
+RESULT_VOLUME_NAME = "result-volume"
 DEFAULT_PYTHON_VERSION = "3.10"

 # Environment variables
@@ -109,3 +110,6 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
     SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
+
+# Magic attributes
+IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
|