snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/jobs/__init__.py +4 -0
  4. snowflake/ml/jobs/_interop/__init__.py +0 -0
  5. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  6. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  7. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  8. snowflake/ml/jobs/_interop/legacy.py +225 -0
  9. snowflake/ml/jobs/_interop/protocols.py +471 -0
  10. snowflake/ml/jobs/_interop/results.py +51 -0
  11. snowflake/ml/jobs/_interop/utils.py +144 -0
  12. snowflake/ml/jobs/_utils/constants.py +4 -1
  13. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  14. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  16. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  17. snowflake/ml/jobs/_utils/types.py +10 -0
  18. snowflake/ml/jobs/job.py +168 -36
  19. snowflake/ml/jobs/manager.py +36 -38
  20. snowflake/ml/model/_client/model/model_version_impl.py +39 -7
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  22. snowflake/ml/model/_client/sql/model_version.py +3 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
  24. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  25. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  26. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  27. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  29. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  30. snowflake/ml/version.py +1 -1
  31. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
  32. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
  33. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  34. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_interop/legacy.py
@@ -0,0 +1,225 @@
+ """Legacy result serialization protocol support for ML Jobs.
+
+ This module provides backward compatibility with the result serialization protocol used by
+ mljob_launcher.py prior to snowflake-ml-python>=1.17.0
+
+ LEGACY PROTOCOL (v1):
+ ---------------------
+ The old serialization protocol (save_mljob_result_v1 in mljob_launcher.py) worked as follows:
+
+ 1. Results were stored in an ExecutionResult dataclass with two optional fields:
+    - result: Any = None  # For successful executions
+    - exception: BaseException = None  # For failed executions
+
+ 2. The ExecutionResult was converted to a dictionary via to_dict():
+    Success case:
+      {"success": True, "result_type": <type qualname>, "result": <value>}
+
+    Failure case:
+      {"success": False, "exc_type": "<module>.<class>", "exc_value": <exception>,
+       "exc_tb": <formatted traceback string>}
+
+ 3. The dictionary was serialized TWICE for fault tolerance:
+    - Primary: cloudpickle to .pkl file under output/mljob_result.pkl (supports complex Python objects)
+    - Fallback: JSON to .json file under output/mljob_result.json (for cross-version compatibility)
+
+ WHY THIS MODULE EXISTS:
+ -----------------------
+ Jobs submitted with client versions using the v1 protocol will write v1-format result files.
+ This module ensures that newer clients can still retrieve results from:
+ - Jobs submitted before the protocol change
+ - Jobs running in environments where snowflake.ml.jobs._interop is not available
+   (triggering the ImportError fallback to v1 in save_mljob_result)
+
+ RETRIEVAL FLOW:
+ ---------------
+ fetch_result() implements the v1 retrieval logic:
+ 1. Try to unpickle from .pkl file
+ 2. On failure (version mismatch, missing imports, etc.), fall back to .json file
+ 3. Convert the legacy dict format to ExecutionResult
+ 4. Provide helpful error messages for common failure modes
+
+ REMOVAL IMPLICATIONS:
+ ---------------------
+ Removing this module would break result retrieval for:
+ - Any jobs that were submitted with snowflake-ml-python<1.17.0 and are still running/completed
+ - Any jobs running in old runtime environments that fall back to v1 serialization
+
+ Safe to remove when:
+ - All ML Runtime images have been updated to include the new _interop modules
+ - Sufficient time has passed that no jobs using the old protocol are still retrievable
+   (consider retention policies for job history/logs)
+ """
+
+ import json
+ import os
+ import pickle
+ import sys
+ import traceback
+ from dataclasses import dataclass
+ from typing import Any, Optional, Union
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import exception_utils, results
+ from snowflake.snowpark import exceptions as sp_exceptions
+
+
+ @dataclass(frozen=True)
+ class ExecutionResult:
+     result: Any = None
+     exception: Optional[BaseException] = None
+
+     @property
+     def success(self) -> bool:
+         return self.exception is None
+
+     def to_dict(self) -> dict[str, Any]:
+         """Return the serializable dictionary."""
+         if isinstance(self.exception, BaseException):
+             exc_type = type(self.exception)
+             return {
+                 "success": False,
+                 "exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
+                 "exc_value": self.exception,
+                 "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
+             }
+         return {
+             "success": True,
+             "result_type": type(self.result).__qualname__,
+             "result": self.result,
+         }
+
+     @classmethod
+     def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
+         if not isinstance(result_dict.get("success"), bool):
+             raise ValueError("Invalid result dictionary")
+
+         if result_dict["success"]:
+             # Load successful result
+             return cls(result=result_dict.get("result"))
+
+         # Load exception
+         exc_type = result_dict.get("exc_type", "RuntimeError")
+         exc_value = result_dict.get("exc_value", "Unknown error")
+         exc_tb = result_dict.get("exc_tb", "")
+         return cls(exception=load_exception(exc_type, exc_value, exc_tb))
+
+
+ def fetch_result(
+     session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None
+ ) -> ExecutionResult:
+     """
+     Fetch the serialized result from the specified path.
+
+     Args:
+         session: Snowpark Session to use for file operations.
+         result_path: The path to the serialized result file.
+         result_json: Optional pre-loaded JSON result dictionary to use instead of fetching from file.
+
+     Returns:
+         A dictionary containing the execution result if available, None otherwise.
+
+     Raises:
+         RuntimeError: If both pickle and JSON result retrieval fail.
+     """
+     try:
+         with session.file.get_stream(result_path) as result_stream:
+             return ExecutionResult.from_dict(pickle.load(result_stream))
+     except (
+         sp_exceptions.SnowparkSQLException,
+         pickle.UnpicklingError,
+         TypeError,
+         ImportError,
+         AttributeError,
+         MemoryError,
+     ) as pickle_error:
+         # Fall back to JSON result if loading pickled result fails for any reason
+         try:
+             if result_json is None:
+                 result_json_path = os.path.splitext(result_path)[0] + ".json"
+                 with session.file.get_stream(result_json_path) as result_stream:
+                     result_json = json.load(result_stream)
+             return ExecutionResult.from_dict(result_json)
+         except Exception as json_error:
+             # Both pickle and JSON failed - provide helpful error message
+             raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
+
+
+ def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
+     """Create helpful error messages for common result retrieval failures."""
+
+     # Package import issues
+     if isinstance(error, ImportError):
+         return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
+
+     # Package versions differ between runtime and local environment
+     if isinstance(error, AttributeError):
+         return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
+
+     # Serialization issues
+     if isinstance(error, TypeError):
+         return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
+
+     # Python version pickling incompatibility
+     if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
+         client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
+         runtime_version = "Python 3.10"  # NOTE: This may be inaccurate, but this path isn't maintained anymore
+         return (
+             f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
+             f"local environment using Python {client_version}. Error: {str(error)}"
+         )
+
+     # File access issues
+     if isinstance(error, sp_exceptions.SnowparkSQLException):
+         if "not found" in str(error).lower() or "does not exist" in str(error).lower():
+             return (
+                 f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
+                 f"errors. Error: {str(error)}"
+             )
+         else:
+             return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"
+
+     if isinstance(error, MemoryError):
+         return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
+
+     # Generic fallback
+     base_message = f"Failed to retrieve job result: {str(error)}"
+     if json_error:
+         base_message += f" (JSON fallback also failed: {str(json_error)})"
+     return base_message
+
+
+ def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> BaseException:
+     """
+     Create an exception with a string-formatted traceback.
+
+     When this exception is raised and not caught, it will display the original traceback.
+     When caught, it behaves like a regular exception without showing the traceback.
+
+     Args:
+         exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
+         exc_value: The deserialized exception value or exception string (i.e. message)
+         exc_tb: String representation of the traceback
+
+     Returns:
+         An exception object with the original traceback information
+
+     # noqa: DAR401
+     """
+     if isinstance(exc_value, Exception):
+         exception = exc_value
+         return exception_utils.attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
+     return exception_utils.build_exception(exc_type_name, str(exc_value), exc_tb)
+
+
+ def load_legacy_result(
+     session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None
+ ) -> results.ExecutionResult:
+     # Load result using legacy interop
+     legacy_result = fetch_result(session, result_path, result_json=result_json)
+
+     # Adapt legacy result to new result
+     return results.ExecutionResult(
+         success=legacy_result.success,
+         value=legacy_result.exception or legacy_result.result,
+     )
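
The new legacy.py above adapts the old success/result/exception dictionary into the new (success, value) result shape. A minimal, self-contained sketch of that mapping follows for reference; it is illustrative only and not part of the package, and AdaptedResult / adapt_v1_dict are hypothetical stand-ins for results.ExecutionResult and the logic in ExecutionResult.from_dict / load_legacy_result.

# Illustrative sketch only (not part of the package): hypothetical stand-ins for
# the adaptation performed by ExecutionResult.from_dict / load_legacy_result.
from dataclasses import dataclass
from typing import Any


@dataclass
class AdaptedResult:  # hypothetical stand-in for the new results.ExecutionResult shape
    success: bool
    value: Any


def adapt_v1_dict(result_dict: dict) -> AdaptedResult:
    """Map a v1 result dictionary onto the new (success, value) shape."""
    if not isinstance(result_dict.get("success"), bool):
        raise ValueError("Invalid result dictionary")
    if result_dict["success"]:
        # v1 success case: {"success": True, "result_type": ..., "result": ...}
        return AdaptedResult(success=True, value=result_dict.get("result"))
    # v1 failure case: {"success": False, "exc_type": ..., "exc_value": ..., "exc_tb": ...}
    exc_value = result_dict.get("exc_value", "Unknown error")
    exc = exc_value if isinstance(exc_value, Exception) else RuntimeError(str(exc_value))
    return AdaptedResult(success=False, value=exc)


ok = adapt_v1_dict({"success": True, "result_type": "int", "result": 42})
failed = adapt_v1_dict({"success": False, "exc_type": "builtins.ValueError",
                        "exc_value": "bad input", "exc_tb": "Traceback (most recent call last): ..."})
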
snowflake/ml/jobs/_interop/protocols.py
@@ -0,0 +1,471 @@
+ import base64
+ import logging
+ import pickle
+ import posixpath
+ import sys
+ from typing import Any, Callable, Optional, Protocol, Union, cast, runtime_checkable
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import data_utils
+ from snowflake.ml.jobs._interop.dto_schema import (
+     BinaryManifest,
+     ParquetManifest,
+     ProtocolInfo,
+ )
+
+ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
+
+ logger = logging.getLogger(__name__)
+
+
+ class SerializationError(TypeError):
+     """Exception raised when a serialization protocol fails."""
+
+
+ class DeserializationError(ValueError):
+     """Exception raised when a serialization protocol fails."""
+
+
+ class InvalidPayloadError(DeserializationError):
+     """Exception raised when the payload is invalid."""
+
+
+ class ProtocolMismatchError(DeserializationError):
+     """Exception raised when the protocol of the serialization protocol is incompatible."""
+
+
+ class VersionMismatchError(ProtocolMismatchError):
+     """Exception raised when the version of the serialization protocol is incompatible."""
+
+
+ class ProtocolNotFoundError(SerializationError):
+     """Exception raised when no suitable serialization protocol is available."""
+
+
+ @runtime_checkable
+ class SerializationProtocol(Protocol):
+     """
+     More advanced protocol which supports more flexibility in how results are saved or loaded.
+     Results can be saved as one or more files, or directly inline in the PayloadManifest.
+     If saving as files, the PayloadManifest can save arbitrary "manifest" information.
+     """
+
+     @property
+     def supported_types(self) -> Condition:
+         """The types that the protocol supports."""
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         """The information about the protocol."""
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+
+
+ class CloudPickleProtocol(SerializationProtocol):
+     """
+     CloudPickle serialization protocol.
+     Uses BinaryManifest for manifest schema.
+     """
+
+     DEFAULT_PATH = "mljob_extra.pkl"
+
+     def __init__(self) -> None:
+         import cloudpickle as cp
+
+         self._backend = cp
+
+     def _get_compatibility_error(self, payload_info: ProtocolInfo) -> Optional[Exception]:
+         """Check compatibility and attempt load, raising helpful errors on failure."""
+         version_error = python_error = None
+
+         # Check cloudpickle version compatibility
+         if payload_info.version:
+             try:
+                 from packaging import version
+
+                 payload_major, current_major = (
+                     version.parse(payload_info.version).major,
+                     version.parse(self._backend.__version__).major,
+                 )
+                 if payload_major != current_major:
+                     version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                         payload_info.version, self._backend.__version__
+                     )
+             except Exception:
+                 if payload_info.version != self.protocol_info.version:
+                     version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                         payload_info.version, self.protocol_info.version
+                     )
+
+         # Check Python version compatibility
+         if payload_info.metadata and "python_version" in payload_info.metadata:
+             payload_py, current_py = (
+                 payload_info.metadata["python_version"],
+                 f"{sys.version_info.major}.{sys.version_info.minor}",
+             )
+             if payload_py != current_py:
+                 python_error = f"Python version mismatch: payload={payload_py}, current={current_py}"
+
+         if version_error or python_error:
+             errors = [err for err in [version_error, python_error] if err]
+             return VersionMismatchError(f"Load failed due to incompatibility: {'; '.join(errors)}")
+         return None
+
+     @property
+     def supported_types(self) -> Condition:
+         return None  # All types are supported
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="cloudpickle",
+             version=self._backend.__version__,
+             metadata={
+                 "python_version": f"{sys.version_info.major}.{sys.version_info.minor}",
+             },
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
+         with data_utils.open_stream(result_path, "wb", session=session) as f:
+             self._backend.dump(obj, f)
+         manifest: BinaryManifest = {"path": result_path}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(BinaryManifest, payload_info.manifest)
+         try:
+             if payload_bytes := payload_manifest.get("bytes"):
+                 return self._backend.loads(payload_bytes)
+             if payload_b64 := payload_manifest.get("base64"):
+                 return self._backend.loads(base64.b64decode(payload_b64))
+             result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+             with data_utils.open_stream(result_path, "rb", session=session) as f:
+                 return self._backend.load(f)
+         except (
+             pickle.UnpicklingError,
+             TypeError,
+             AttributeError,
+             MemoryError,
+         ) as pickle_error:
+             if error := self._get_compatibility_error(payload_info):
+                 raise error from pickle_error
+             raise
+
+
+ class ArrowTableProtocol(SerializationProtocol):
+     """
+     Arrow Table serialization protocol.
+     Uses ParquetManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+     def __init__(self) -> None:
+         import pyarrow as pa
+         import pyarrow.parquet as pq
+
+         self._pa = pa
+         self._pq = pq
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._pa.Table)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="pyarrow",
+             version=self._pa.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._pa.Table):
+             raise SerializationError(f"Expected {self._pa.Table.__name__} object, got {type(obj).__name__}")
+
+         # TODO: Support partitioned writes for large datasets
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             self._pq.write_table(obj, stream)
+
+         manifest: ParquetManifest = {"paths": [result_path]}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(ParquetManifest, payload_info.manifest)
+         tables = []
+         for path in payload_manifest["paths"]:
+             transformed_path = path_transform(path) if path_transform else path
+             with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                 table = self._pq.read_table(f)
+             tables.append(table)
+         return self._pa.concat_tables(tables) if len(tables) > 1 else tables[0]
+
+
+ class PandasDataFrameProtocol(SerializationProtocol):
+     """
+     Pandas DataFrame serialization protocol.
+     Uses ParquetManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+     def __init__(self) -> None:
+         import pandas as pd
+
+         self._pd = pd
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._pd.DataFrame)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="pandas",
+             version=self._pd.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._pd.DataFrame):
+             raise SerializationError(f"Expected {self._pd.DataFrame.__name__} object, got {type(obj).__name__}")
+
+         # TODO: Support partitioned writes for large datasets
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             obj.to_parquet(stream)
+
+         manifest: ParquetManifest = {"paths": [result_path]}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(ParquetManifest, payload_info.manifest)
+         dfs = []
+         for path in payload_manifest["paths"]:
+             transformed_path = path_transform(path) if path_transform else path
+             with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                 df = self._pd.read_parquet(f)
+             dfs.append(df)
+         return self._pd.concat(dfs) if len(dfs) > 1 else dfs[0]
+
+
+ class NumpyArrayProtocol(SerializationProtocol):
+     """
+     Numpy Array serialization protocol.
+     Uses BinaryManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra.npy"
+
+     def __init__(self) -> None:
+         import numpy as np
+
+         self._np = np
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._np.ndarray)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="numpy",
+             version=self._np.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._np.ndarray):
+             raise SerializationError(f"Expected {self._np.ndarray.__name__} object, got {type(obj).__name__}")
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN)
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             self._np.save(stream, obj)
+
+         manifest: BinaryManifest = {"path": result_path}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(BinaryManifest, payload_info.manifest)
+         transformed_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+         with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+             return self._np.load(f)
+
+
+ class AutoProtocol(SerializationProtocol):
+     def __init__(self) -> None:
+         self._protocols: list[SerializationProtocol] = []
+         self._protocol_info = ProtocolInfo(
+             name="auto",
+             version=None,
+             metadata=None,
+         )
+
+     @property
+     def supported_types(self) -> Condition:
+         return None  # All types are supported
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return self._protocol_info
+
+     def try_register_protocol(
+         self,
+         klass: type[SerializationProtocol],
+         *args: Any,
+         index: int = 0,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Try to construct and register a protocol. If the protocol cannot be constructed,
+         log a warning and skip registration. By default (index=0), the most recently
+         registered protocol takes precedence.
+
+         Args:
+             klass: The class of the protocol to register.
+             args: The positional arguments to pass to the protocol constructor.
+             index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.
+             kwargs: The keyword arguments to pass to the protocol constructor.
+         """
+         try:
+             protocol = klass(*args, **kwargs)
+             self.register_protocol(protocol, index=index)
+         except Exception as e:
+             logger.warning(f"Failed to register protocol {klass}: {e}")
+
+     def register_protocol(
+         self,
+         protocol: SerializationProtocol,
+         index: int = 0,
+     ) -> None:
+         """
+         Register a protocol with a condition. By default (index=0), the most recently
+         registered protocol takes precedence.
+
+         Args:
+             protocol: The protocol to register.
+             index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.
+
+         Raises:
+             ValueError: If the condition is invalid.
+             ValueError: If the index is invalid.
+         """
+         # Validate condition
+         # TODO: Build lookup table of supported types to protocols (in priority order)
+         # for faster lookup at save/load time (instead of iterating over all protocols)
+         if not isinstance(protocol, SerializationProtocol):
+             raise ValueError(f"Invalid protocol type: {type(protocol)}. Expected SerializationProtocol.")
+         if index == -1:
+             self._protocols.append(protocol)
+         elif index < 0:
+             raise ValueError(f"Invalid index: {index}. Expected -1 or >= 0.")
+         else:
+             self._protocols.insert(index, protocol)
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         last_protocol_error = None
+         for protocol in self._protocols:
+             try:
+                 if self._is_supported_type(obj, protocol):
+                     logger.debug(f"Dumping object of type {type(obj)} with protocol {protocol}")
+                     return protocol.save(obj, dest_dir, session)
+             except Exception as e:
+                 logger.warning(f"Error dumping object {obj} with protocol {protocol}: {repr(e)}")
+                 last_protocol_error = (protocol.protocol_info, e)
+         last_error_str = (
+             f", most recent error ({last_protocol_error[0]}): {repr(last_protocol_error[1])}"
+             if last_protocol_error
+             else ""
+         )
+         raise ProtocolNotFoundError(
+             f"No suitable protocol found for type {type(obj).__name__}"
+             f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)}){last_error_str}"
+         )
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         last_error = None
+         for protocol in self._protocols:
+             if protocol.protocol_info.name == payload_info.name:
+                 try:
+                     return protocol.load(payload_info, session, path_transform)
+                 except Exception as e:
+                     logger.warning(f"Error loading object with protocol {protocol}: {repr(e)}")
+                     last_error = e
+         if last_error:
+             raise last_error
+         raise ProtocolNotFoundError(
+             f"No protocol matching {payload_info} available"
+             f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)})"
+             ", possibly due to snowflake-ml-python package version mismatch"
+         )
+
+     def _is_supported_type(self, obj: Any, protocol: SerializationProtocol) -> bool:
+         if protocol.supported_types is None:
+             return True  # None means all types are supported
+         elif isinstance(protocol.supported_types, (type, tuple)):
+             return isinstance(obj, protocol.supported_types)
+         elif callable(protocol.supported_types):
+             return protocol.supported_types(obj) is True
+         raise ValueError(f"Invalid supported types: {protocol.supported_types} for protocol {protocol}")