snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +36 -38
- snowflake/ml/model/_client/model/model_version_impl.py +39 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_interop/legacy.py
@@ -0,0 +1,225 @@
"""Legacy result serialization protocol support for ML Jobs.

This module provides backward compatibility with the result serialization protocol used by
mljob_launcher.py prior to snowflake-ml-python>=1.17.0

LEGACY PROTOCOL (v1):
---------------------
The old serialization protocol (save_mljob_result_v1 in mljob_launcher.py) worked as follows:

1. Results were stored in an ExecutionResult dataclass with two optional fields:
   - result: Any = None  # For successful executions
   - exception: BaseException = None  # For failed executions

2. The ExecutionResult was converted to a dictionary via to_dict():
   Success case:
       {"success": True, "result_type": <type qualname>, "result": <value>}

   Failure case:
       {"success": False, "exc_type": "<module>.<class>", "exc_value": <exception>,
        "exc_tb": <formatted traceback string>}

3. The dictionary was serialized TWICE for fault tolerance:
   - Primary: cloudpickle to .pkl file under output/mljob_result.pkl (supports complex Python objects)
   - Fallback: JSON to .json file under output/mljob_result.json (for cross-version compatibility)

WHY THIS MODULE EXISTS:
-----------------------
Jobs submitted with client versions using the v1 protocol will write v1-format result files.
This module ensures that newer clients can still retrieve results from:
- Jobs submitted before the protocol change
- Jobs running in environments where snowflake.ml.jobs._interop is not available
  (triggering the ImportError fallback to v1 in save_mljob_result)

RETRIEVAL FLOW:
---------------
fetch_result() implements the v1 retrieval logic:
1. Try to unpickle from .pkl file
2. On failure (version mismatch, missing imports, etc.), fall back to .json file
3. Convert the legacy dict format to ExecutionResult
4. Provide helpful error messages for common failure modes

REMOVAL IMPLICATIONS:
---------------------
Removing this module would break result retrieval for:
- Any jobs that were submitted with snowflake-ml-python<1.17.0 and are still running/completed
- Any jobs running in old runtime environments that fall back to v1 serialization

Safe to remove when:
- All ML Runtime images have been updated to include the new _interop modules
- Sufficient time has passed that no jobs using the old protocol are still retrievable
  (consider retention policies for job history/logs)
"""

import json
import os
import pickle
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Optional, Union

from snowflake import snowpark
from snowflake.ml.jobs._interop import exception_utils, results
from snowflake.snowpark import exceptions as sp_exceptions


@dataclass(frozen=True)
class ExecutionResult:
    result: Any = None
    exception: Optional[BaseException] = None

    @property
    def success(self) -> bool:
        return self.exception is None

    def to_dict(self) -> dict[str, Any]:
        """Return the serializable dictionary."""
        if isinstance(self.exception, BaseException):
            exc_type = type(self.exception)
            return {
                "success": False,
                "exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
                "exc_value": self.exception,
                "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
            }
        return {
            "success": True,
            "result_type": type(self.result).__qualname__,
            "result": self.result,
        }

    @classmethod
    def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
        if not isinstance(result_dict.get("success"), bool):
            raise ValueError("Invalid result dictionary")

        if result_dict["success"]:
            # Load successful result
            return cls(result=result_dict.get("result"))

        # Load exception
        exc_type = result_dict.get("exc_type", "RuntimeError")
        exc_value = result_dict.get("exc_value", "Unknown error")
        exc_tb = result_dict.get("exc_tb", "")
        return cls(exception=load_exception(exc_type, exc_value, exc_tb))


def fetch_result(
    session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None
) -> ExecutionResult:
    """
    Fetch the serialized result from the specified path.

    Args:
        session: Snowpark Session to use for file operations.
        result_path: The path to the serialized result file.
        result_json: Optional pre-loaded JSON result dictionary to use instead of fetching from file.

    Returns:
        The deserialized ExecutionResult, reconstructed from the pickle file or, failing that, the JSON fallback.

    Raises:
        RuntimeError: If both pickle and JSON result retrieval fail.
    """
    try:
        with session.file.get_stream(result_path) as result_stream:
            return ExecutionResult.from_dict(pickle.load(result_stream))
    except (
        sp_exceptions.SnowparkSQLException,
        pickle.UnpicklingError,
        TypeError,
        ImportError,
        AttributeError,
        MemoryError,
    ) as pickle_error:
        # Fall back to JSON result if loading pickled result fails for any reason
        try:
            if result_json is None:
                result_json_path = os.path.splitext(result_path)[0] + ".json"
                with session.file.get_stream(result_json_path) as result_stream:
                    result_json = json.load(result_stream)
            return ExecutionResult.from_dict(result_json)
        except Exception as json_error:
            # Both pickle and JSON failed - provide helpful error message
            raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error


def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
    """Create helpful error messages for common result retrieval failures."""

    # Package import issues
    if isinstance(error, ImportError):
        return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"

    # Package versions differ between runtime and local environment
    if isinstance(error, AttributeError):
        return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"

    # Serialization issues
    if isinstance(error, TypeError):
        return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"

    # Python version pickling incompatibility
    if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
        client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
        runtime_version = "Python 3.10"  # NOTE: This may be inaccurate, but this path isn't maintained anymore
        return (
            f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
f"local environment using Python {client_version}. Error: {str(error)}"
        )

    # File access issues
    if isinstance(error, sp_exceptions.SnowparkSQLException):
        if "not found" in str(error).lower() or "does not exist" in str(error).lower():
            return (
                f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
                f"errors. Error: {str(error)}"
            )
        else:
            return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"

    if isinstance(error, MemoryError):
        return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"

    # Generic fallback
    base_message = f"Failed to retrieve job result: {str(error)}"
    if json_error:
        base_message += f" (JSON fallback also failed: {str(json_error)})"
    return base_message


def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> BaseException:
    """
    Create an exception with a string-formatted traceback.

    When this exception is raised and not caught, it will display the original traceback.
    When caught, it behaves like a regular exception without showing the traceback.

    Args:
        exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
        exc_value: The deserialized exception value or exception string (i.e. message)
        exc_tb: String representation of the traceback

    Returns:
        An exception object with the original traceback information

    # noqa: DAR401
    """
    if isinstance(exc_value, Exception):
        exception = exc_value
        return exception_utils.attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
    return exception_utils.build_exception(exc_type_name, str(exc_value), exc_tb)


def load_legacy_result(
    session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None
) -> results.ExecutionResult:
    # Load result using legacy interop
    legacy_result = fetch_result(session, result_path, result_json=result_json)

    # Adapt legacy result to new result
    return results.ExecutionResult(
        success=legacy_result.success,
        value=legacy_result.exception or legacy_result.result,
    )
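As a quick illustration (not part of the diff itself), the two v1 payload shapes described in the module docstring round-trip through ExecutionResult.from_dict as sketched below. The payload values are invented, and running this assumes snowflake-ml-python>=1.17.0 is installed locally so that exception_utils resolves:

# Sketch only: v1 result dicts, as documented above, fed back through from_dict.
from snowflake.ml.jobs._interop import legacy

success_payload = {"success": True, "result_type": "dict", "result": {"accuracy": 0.97}}
ok = legacy.ExecutionResult.from_dict(success_payload)
assert ok.success and ok.result == {"accuracy": 0.97}

failure_payload = {
    "success": False,
    "exc_type": "builtins.ValueError",
    "exc_value": "bad input",  # the JSON fallback stores the message string, not the object
    "exc_tb": "Traceback (most recent call last):\n  ...\n",
}
err = legacy.ExecutionResult.from_dict(failure_payload)
assert not err.success and err.exception is not None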
snowflake/ml/jobs/_interop/protocols.py
@@ -0,0 +1,471 @@
import base64
import logging
import pickle
import posixpath
import sys
from typing import Any, Callable, Optional, Protocol, Union, cast, runtime_checkable

from snowflake import snowpark
from snowflake.ml.jobs._interop import data_utils
from snowflake.ml.jobs._interop.dto_schema import (
    BinaryManifest,
    ParquetManifest,
    ProtocolInfo,
)

Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]

logger = logging.getLogger(__name__)


class SerializationError(TypeError):
    """Exception raised when a serialization protocol fails."""


class DeserializationError(ValueError):
"""Exception raised when a serialization protocol fails."""


class InvalidPayloadError(DeserializationError):
    """Exception raised when the payload is invalid."""


class ProtocolMismatchError(DeserializationError):
"""Exception raised when the protocol of the serialization protocol is incompatible."""


class VersionMismatchError(ProtocolMismatchError):
    """Exception raised when the version of the serialization protocol is incompatible."""


class ProtocolNotFoundError(SerializationError):
    """Exception raised when no suitable serialization protocol is available."""


@runtime_checkable
class SerializationProtocol(Protocol):
    """
    Flexible serialization protocol that controls how results are saved and loaded.
    Results can be saved as one or more files, or directly inline in the PayloadManifest.
    If saving as files, the PayloadManifest can save arbitrary "manifest" information.
    """

    @property
    def supported_types(self) -> Condition:
        """The types that the protocol supports."""

    @property
    def protocol_info(self) -> ProtocolInfo:
        """The information about the protocol."""

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""


class CloudPickleProtocol(SerializationProtocol):
    """
    CloudPickle serialization protocol.
    Uses BinaryManifest for manifest schema.
    """

    DEFAULT_PATH = "mljob_extra.pkl"

    def __init__(self) -> None:
        import cloudpickle as cp

        self._backend = cp

    def _get_compatibility_error(self, payload_info: ProtocolInfo) -> Optional[Exception]:
"""Check compatibility and attempt load, raising helpful errors on failure."""
        version_error = python_error = None

        # Check cloudpickle version compatibility
        if payload_info.version:
            try:
                from packaging import version

                payload_major, current_major = (
                    version.parse(payload_info.version).major,
                    version.parse(self._backend.__version__).major,
                )
                if payload_major != current_major:
                    version_error = "cloudpickle version mismatch: payload={}, current={}".format(
                        payload_info.version, self._backend.__version__
                    )
            except Exception:
                if payload_info.version != self.protocol_info.version:
                    version_error = "cloudpickle version mismatch: payload={}, current={}".format(
                        payload_info.version, self.protocol_info.version
                    )

        # Check Python version compatibility
        if payload_info.metadata and "python_version" in payload_info.metadata:
            payload_py, current_py = (
                payload_info.metadata["python_version"],
                f"{sys.version_info.major}.{sys.version_info.minor}",
            )
            if payload_py != current_py:
                python_error = f"Python version mismatch: payload={payload_py}, current={current_py}"

        if version_error or python_error:
            errors = [err for err in [version_error, python_error] if err]
            return VersionMismatchError(f"Load failed due to incompatibility: {'; '.join(errors)}")
        return None

    @property
    def supported_types(self) -> Condition:
        return None  # All types are supported

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(
            name="cloudpickle",
            version=self._backend.__version__,
            metadata={
                "python_version": f"{sys.version_info.major}.{sys.version_info.minor}",
            },
        )

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
        with data_utils.open_stream(result_path, "wb", session=session) as f:
            self._backend.dump(obj, f)
        manifest: BinaryManifest = {"path": result_path}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        payload_manifest = cast(BinaryManifest, payload_info.manifest)
        try:
            if payload_bytes := payload_manifest.get("bytes"):
                return self._backend.loads(payload_bytes)
            if payload_b64 := payload_manifest.get("base64"):
                return self._backend.loads(base64.b64decode(payload_b64))
            result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
            with data_utils.open_stream(result_path, "rb", session=session) as f:
                return self._backend.load(f)
        except (
            pickle.UnpicklingError,
            TypeError,
            AttributeError,
            MemoryError,
        ) as pickle_error:
            if error := self._get_compatibility_error(payload_info):
                raise error from pickle_error
            raise


class ArrowTableProtocol(SerializationProtocol):
    """
    Arrow Table serialization protocol.
    Uses ParquetManifest for manifest schema.
    """

    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"

    def __init__(self) -> None:
        import pyarrow as pa
        import pyarrow.parquet as pq

        self._pa = pa
        self._pq = pq

    @property
    def supported_types(self) -> Condition:
        return cast(type, self._pa.Table)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(
            name="pyarrow",
            version=self._pa.__version__,
        )

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._pa.Table):
            raise SerializationError(f"Expected {self._pa.Table.__name__} object, got {type(obj).__name__}")

        # TODO: Support partitioned writes for large datasets
        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
        with data_utils.open_stream(result_path, "wb", session=session) as stream:
            self._pq.write_table(obj, stream)

        manifest: ParquetManifest = {"paths": [result_path]}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        payload_manifest = cast(ParquetManifest, payload_info.manifest)
        tables = []
        for path in payload_manifest["paths"]:
            transformed_path = path_transform(path) if path_transform else path
            with data_utils.open_stream(transformed_path, "rb", session=session) as f:
                table = self._pq.read_table(f)
            tables.append(table)
        return self._pa.concat_tables(tables) if len(tables) > 1 else tables[0]


class PandasDataFrameProtocol(SerializationProtocol):
    """
    Pandas DataFrame serialization protocol.
    Uses ParquetManifest for manifest schema.
    """

    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"

    def __init__(self) -> None:
        import pandas as pd

        self._pd = pd

    @property
    def supported_types(self) -> Condition:
        return cast(type, self._pd.DataFrame)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(
            name="pandas",
            version=self._pd.__version__,
        )

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._pd.DataFrame):
            raise SerializationError(f"Expected {self._pd.DataFrame.__name__} object, got {type(obj).__name__}")

        # TODO: Support partitioned writes for large datasets
        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
        with data_utils.open_stream(result_path, "wb", session=session) as stream:
            obj.to_parquet(stream)

        manifest: ParquetManifest = {"paths": [result_path]}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        payload_manifest = cast(ParquetManifest, payload_info.manifest)
        dfs = []
        for path in payload_manifest["paths"]:
            transformed_path = path_transform(path) if path_transform else path
            with data_utils.open_stream(transformed_path, "rb", session=session) as f:
                df = self._pd.read_parquet(f)
            dfs.append(df)
        return self._pd.concat(dfs) if len(dfs) > 1 else dfs[0]


class NumpyArrayProtocol(SerializationProtocol):
    """
    Numpy Array serialization protocol.
    Uses BinaryManifest for manifest schema.
    """

    DEFAULT_PATH_PATTERN = "mljob_extra.npy"

    def __init__(self) -> None:
        import numpy as np

        self._np = np

    @property
    def supported_types(self) -> Condition:
        return cast(type, self._np.ndarray)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(
            name="numpy",
            version=self._np.__version__,
        )

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._np.ndarray):
            raise SerializationError(f"Expected {self._np.ndarray.__name__} object, got {type(obj).__name__}")
        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN)
        with data_utils.open_stream(result_path, "wb", session=session) as stream:
            self._np.save(stream, obj)

        manifest: BinaryManifest = {"path": result_path}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        payload_manifest = cast(BinaryManifest, payload_info.manifest)
        transformed_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
        with data_utils.open_stream(transformed_path, "rb", session=session) as f:
            return self._np.load(f)


class AutoProtocol(SerializationProtocol):
    def __init__(self) -> None:
        self._protocols: list[SerializationProtocol] = []
        self._protocol_info = ProtocolInfo(
            name="auto",
            version=None,
            metadata=None,
        )

    @property
    def supported_types(self) -> Condition:
        return None  # All types are supported

    @property
    def protocol_info(self) -> ProtocolInfo:
        return self._protocol_info

    def try_register_protocol(
        self,
        klass: type[SerializationProtocol],
        *args: Any,
        index: int = 0,
        **kwargs: Any,
    ) -> None:
        """
        Try to construct and register a protocol. If the protocol cannot be constructed,
        log a warning and skip registration. By default (index=0), the most recently
        registered protocol takes precedence.

        Args:
            klass: The class of the protocol to register.
            args: The positional arguments to pass to the protocol constructor.
            index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.
            kwargs: The keyword arguments to pass to the protocol constructor.
        """
        try:
            protocol = klass(*args, **kwargs)
            self.register_protocol(protocol, index=index)
        except Exception as e:
            logger.warning(f"Failed to register protocol {klass}: {e}")

    def register_protocol(
        self,
        protocol: SerializationProtocol,
        index: int = 0,
    ) -> None:
        """
        Register a protocol in the dispatch list. By default (index=0), the most recently
        registered protocol takes precedence.

        Args:
            protocol: The protocol to register.
            index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.

        Raises:
            ValueError: If the protocol is not a SerializationProtocol.
            ValueError: If the index is invalid.
        """
        # Validate protocol and index
        # TODO: Build lookup table of supported types to protocols (in priority order)
        # for faster lookup at save/load time (instead of iterating over all protocols)
        if not isinstance(protocol, SerializationProtocol):
            raise ValueError(f"Invalid protocol type: {type(protocol)}. Expected SerializationProtocol.")
        if index == -1:
            self._protocols.append(protocol)
        elif index < 0:
            raise ValueError(f"Invalid index: {index}. Expected -1 or >= 0.")
        else:
            self._protocols.insert(index, protocol)

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        last_protocol_error = None
        for protocol in self._protocols:
            try:
                if self._is_supported_type(obj, protocol):
                    logger.debug(f"Dumping object of type {type(obj)} with protocol {protocol}")
                    return protocol.save(obj, dest_dir, session)
            except Exception as e:
                logger.warning(f"Error dumping object {obj} with protocol {protocol}: {repr(e)}")
                last_protocol_error = (protocol.protocol_info, e)
        last_error_str = (
            f", most recent error ({last_protocol_error[0]}): {repr(last_protocol_error[1])}"
            if last_protocol_error
            else ""
        )
        raise ProtocolNotFoundError(
            f"No suitable protocol found for type {type(obj).__name__}"
            f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)}){last_error_str}"
        )

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        last_error = None
        for protocol in self._protocols:
            if protocol.protocol_info.name == payload_info.name:
                try:
                    return protocol.load(payload_info, session, path_transform)
                except Exception as e:
                    logger.warning(f"Error loading object with protocol {protocol}: {repr(e)}")
                    last_error = e
        if last_error:
            raise last_error
        raise ProtocolNotFoundError(
            f"No protocol matching {payload_info} available"
            f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)})"
            ", possibly due to snowflake-ml-python package version mismatch"
        )

    def _is_supported_type(self, obj: Any, protocol: SerializationProtocol) -> bool:
        if protocol.supported_types is None:
            return True  # None means all types are supported
        elif isinstance(protocol.supported_types, (type, tuple)):
            return isinstance(obj, protocol.supported_types)
        elif callable(protocol.supported_types):
            return protocol.supported_types(obj) is True
        raise ValueError(f"Invalid supported types: {protocol.supported_types} for protocol {protocol}")
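To make AutoProtocol's dispatch concrete, here is a hand-wired sketch. This wiring is not part of the diff; the real registration happens elsewhere in the _interop package, so the ordering below is illustrative, and it assumes cloudpickle, pandas, and pyarrow are importable locally:

# Sketch only: register concrete protocols on an AutoProtocol by hand.
from snowflake.ml.jobs._interop.protocols import (
    ArrowTableProtocol,
    AutoProtocol,
    CloudPickleProtocol,
    PandasDataFrameProtocol,
)

auto = AutoProtocol()
auto.try_register_protocol(CloudPickleProtocol, index=-1)  # generic fallback, checked last
auto.try_register_protocol(ArrowTableProtocol)             # index=0: newest registration wins
auto.try_register_protocol(PandasDataFrameProtocol)

# save() walks the list in order and uses the first protocol whose supported_types
# matches the object: a pandas DataFrame selects PandasDataFrameProtocol, a
# pyarrow Table selects ArrowTableProtocol, and anything else falls through to
# CloudPickleProtocol. load() instead matches on the ProtocolInfo.name recorded
# by save() in the payload manifest.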