snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/telemetry.py +3 -2
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  5. snowflake/ml/experiment/callback/keras.py +3 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  7. snowflake/ml/experiment/callback/xgboost.py +3 -0
  8. snowflake/ml/experiment/experiment_tracking.py +19 -7
  9. snowflake/ml/feature_store/feature_store.py +236 -61
  10. snowflake/ml/jobs/__init__.py +4 -0
  11. snowflake/ml/jobs/_interop/__init__.py +0 -0
  12. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  13. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  14. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  15. snowflake/ml/jobs/_interop/legacy.py +225 -0
  16. snowflake/ml/jobs/_interop/protocols.py +471 -0
  17. snowflake/ml/jobs/_interop/results.py +51 -0
  18. snowflake/ml/jobs/_interop/utils.py +144 -0
  19. snowflake/ml/jobs/_utils/constants.py +16 -2
  20. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  21. snowflake/ml/jobs/_utils/payload_utils.py +8 -2
  22. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  23. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  24. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  25. snowflake/ml/jobs/_utils/types.py +15 -0
  26. snowflake/ml/jobs/job.py +186 -40
  27. snowflake/ml/jobs/manager.py +48 -39
  28. snowflake/ml/model/__init__.py +19 -0
  29. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  30. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  31. snowflake/ml/model/_client/model/model_version_impl.py +168 -18
  32. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  33. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  36. snowflake/ml/model/_client/sql/model_version.py +3 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
  39. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  40. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  41. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  42. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  43. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  44. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  46. snowflake/ml/model/type_hints.py +16 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  48. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  49. snowflake/ml/version.py +1 -1
  50. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
  51. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
  52. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  53. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_interop/protocols.py
@@ -0,0 +1,471 @@
+ import base64
+ import logging
+ import pickle
+ import posixpath
+ import sys
+ from typing import Any, Callable, Optional, Protocol, Union, cast, runtime_checkable
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import data_utils
+ from snowflake.ml.jobs._interop.dto_schema import (
+     BinaryManifest,
+     ParquetManifest,
+     ProtocolInfo,
+ )
+
+ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
+
+ logger = logging.getLogger(__name__)
+
+
+ class SerializationError(TypeError):
+     """Exception raised when a serialization protocol fails."""
+
+
+ class DeserializationError(ValueError):
+     """Exception raised when deserialization of a payload fails."""
+
+
+ class InvalidPayloadError(DeserializationError):
+     """Exception raised when the payload is invalid."""
+
+
+ class ProtocolMismatchError(DeserializationError):
+     """Exception raised when the payload was written with an incompatible serialization protocol."""
+
+
+ class VersionMismatchError(ProtocolMismatchError):
+     """Exception raised when the version of the serialization protocol is incompatible."""
+
+
+ class ProtocolNotFoundError(SerializationError):
+     """Exception raised when no suitable serialization protocol is available."""
+
+
+ @runtime_checkable
+ class SerializationProtocol(Protocol):
+     """
+     Serialization protocol that allows flexibility in how results are saved and loaded.
+     Results can be saved as one or more files, or directly inline in the PayloadManifest.
+     If saving as files, the PayloadManifest can carry arbitrary "manifest" information.
+     """
+
+     @property
+     def supported_types(self) -> Condition:
+         """The types that the protocol supports."""
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         """The information about the protocol."""
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+
+
+ class CloudPickleProtocol(SerializationProtocol):
+     """
+     CloudPickle serialization protocol.
+     Uses BinaryManifest for manifest schema.
+     """
+
+     DEFAULT_PATH = "mljob_extra.pkl"
+
+     def __init__(self) -> None:
+         import cloudpickle as cp
+
+         self._backend = cp
+
+     def _get_compatibility_error(self, payload_info: ProtocolInfo) -> Optional[Exception]:
+         """Check version compatibility and return a descriptive error on mismatch, or None."""
+         version_error = python_error = None
+
+         # Check cloudpickle version compatibility
+         if payload_info.version:
+             try:
+                 from packaging import version
+
+                 payload_major, current_major = (
+                     version.parse(payload_info.version).major,
+                     version.parse(self._backend.__version__).major,
+                 )
+                 if payload_major != current_major:
+                     version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                         payload_info.version, self._backend.__version__
+                     )
+             except Exception:
+                 # Fall back to an exact-match check if version parsing fails
+                 if payload_info.version != self.protocol_info.version:
+                     version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                         payload_info.version, self.protocol_info.version
+                     )
+
+         # Check Python version compatibility
+         if payload_info.metadata and "python_version" in payload_info.metadata:
+             payload_py, current_py = (
+                 payload_info.metadata["python_version"],
+                 f"{sys.version_info.major}.{sys.version_info.minor}",
+             )
+             if payload_py != current_py:
+                 python_error = f"Python version mismatch: payload={payload_py}, current={current_py}"
+
+         if version_error or python_error:
+             errors = [err for err in [version_error, python_error] if err]
+             return VersionMismatchError(f"Load failed due to incompatibility: {'; '.join(errors)}")
+         return None
+
+     @property
+     def supported_types(self) -> Condition:
+         return None  # All types are supported
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="cloudpickle",
+             version=self._backend.__version__,
+             metadata={
+                 "python_version": f"{sys.version_info.major}.{sys.version_info.minor}",
+             },
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
+         with data_utils.open_stream(result_path, "wb", session=session) as f:
+             self._backend.dump(obj, f)
+         manifest: BinaryManifest = {"path": result_path}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(BinaryManifest, payload_info.manifest)
+         try:
+             if payload_bytes := payload_manifest.get("bytes"):
+                 return self._backend.loads(payload_bytes)
+             if payload_b64 := payload_manifest.get("base64"):
+                 return self._backend.loads(base64.b64decode(payload_b64))
+             result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+             with data_utils.open_stream(result_path, "rb", session=session) as f:
+                 return self._backend.load(f)
+         except (
+             pickle.UnpicklingError,
+             TypeError,
+             AttributeError,
+             MemoryError,
+         ) as pickle_error:
+             if error := self._get_compatibility_error(payload_info):
+                 raise error from pickle_error
+             raise
+
+
+ class ArrowTableProtocol(SerializationProtocol):
+     """
+     Arrow Table serialization protocol.
+     Uses ParquetManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+     def __init__(self) -> None:
+         import pyarrow as pa
+         import pyarrow.parquet as pq
+
+         self._pa = pa
+         self._pq = pq
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._pa.Table)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="pyarrow",
+             version=self._pa.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._pa.Table):
+             raise SerializationError(f"Expected {self._pa.Table.__name__} object, got {type(obj).__name__}")
+
+         # TODO: Support partitioned writes for large datasets
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             self._pq.write_table(obj, stream)
+
+         manifest: ParquetManifest = {"paths": [result_path]}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(ParquetManifest, payload_info.manifest)
+         tables = []
+         for path in payload_manifest["paths"]:
+             transformed_path = path_transform(path) if path_transform else path
+             with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                 table = self._pq.read_table(f)
+             tables.append(table)
+         return self._pa.concat_tables(tables) if len(tables) > 1 else tables[0]
+
+
+ class PandasDataFrameProtocol(SerializationProtocol):
+     """
+     Pandas DataFrame serialization protocol.
+     Uses ParquetManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+     def __init__(self) -> None:
+         import pandas as pd
+
+         self._pd = pd
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._pd.DataFrame)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="pandas",
+             version=self._pd.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._pd.DataFrame):
+             raise SerializationError(f"Expected {self._pd.DataFrame.__name__} object, got {type(obj).__name__}")
+
+         # TODO: Support partitioned writes for large datasets
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             obj.to_parquet(stream)
+
+         manifest: ParquetManifest = {"paths": [result_path]}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(ParquetManifest, payload_info.manifest)
+         dfs = []
+         for path in payload_manifest["paths"]:
+             transformed_path = path_transform(path) if path_transform else path
+             with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                 df = self._pd.read_parquet(f)
+             dfs.append(df)
+         return self._pd.concat(dfs) if len(dfs) > 1 else dfs[0]
+
+
+ class NumpyArrayProtocol(SerializationProtocol):
+     """
+     Numpy Array serialization protocol.
+     Uses BinaryManifest for manifest schema.
+     """
+
+     DEFAULT_PATH_PATTERN = "mljob_extra.npy"
+
+     def __init__(self) -> None:
+         import numpy as np
+
+         self._np = np
+
+     @property
+     def supported_types(self) -> Condition:
+         return cast(type, self._np.ndarray)
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return ProtocolInfo(
+             name="numpy",
+             version=self._np.__version__,
+         )
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object to the destination directory."""
+         if not isinstance(obj, self._np.ndarray):
+             raise SerializationError(f"Expected {self._np.ndarray.__name__} object, got {type(obj).__name__}")
+         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN)
+         with data_utils.open_stream(result_path, "wb", session=session) as stream:
+             self._np.save(stream, obj)
+
+         manifest: BinaryManifest = {"path": result_path}
+         return self.protocol_info.with_manifest(manifest)
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object from the source directory."""
+         if payload_info.name != self.protocol_info.name:
+             raise ProtocolMismatchError(
+                 f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+             )
+
+         payload_manifest = cast(BinaryManifest, payload_info.manifest)
+         transformed_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+         with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+             return self._np.load(f)
+
+
+ class AutoProtocol(SerializationProtocol):
+     def __init__(self) -> None:
+         self._protocols: list[SerializationProtocol] = []
+         self._protocol_info = ProtocolInfo(
+             name="auto",
+             version=None,
+             metadata=None,
+         )
+
+     @property
+     def supported_types(self) -> Condition:
+         return None  # All types are supported
+
+     @property
+     def protocol_info(self) -> ProtocolInfo:
+         return self._protocol_info
+
+     def try_register_protocol(
+         self,
+         klass: type[SerializationProtocol],
+         *args: Any,
+         index: int = 0,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Try to construct and register a protocol. If the protocol cannot be constructed
+         (e.g. because a required package is unavailable), log a warning and skip
+         registration. By default (index=0), the most recently registered protocol
+         takes precedence.
+
+         Args:
+             klass: The class of the protocol to register.
+             args: The positional arguments to pass to the protocol constructor.
+             index: The index at which to register the protocol. If -1, the protocol is appended to the end of the list.
+             kwargs: The keyword arguments to pass to the protocol constructor.
+         """
+         try:
+             protocol = klass(*args, **kwargs)
+             self.register_protocol(protocol, index=index)
+         except Exception as e:
+             logger.warning(f"Failed to register protocol {klass}: {e}")
+
+     def register_protocol(
+         self,
+         protocol: SerializationProtocol,
+         index: int = 0,
+     ) -> None:
+         """
+         Register a protocol. By default (index=0), the most recently registered
+         protocol takes precedence.
+
+         Args:
+             protocol: The protocol to register.
+             index: The index at which to register the protocol. If -1, the protocol is appended to the end of the list.
+
+         Raises:
+             ValueError: If the protocol is not a SerializationProtocol.
+             ValueError: If the index is invalid.
+         """
+         # TODO: Build lookup table of supported types to protocols (in priority order)
+         # for faster lookup at save/load time (instead of iterating over all protocols)
+         if not isinstance(protocol, SerializationProtocol):
+             raise ValueError(f"Invalid protocol type: {type(protocol)}. Expected SerializationProtocol.")
+         if index == -1:
+             self._protocols.append(protocol)
+         elif index < 0:
+             raise ValueError(f"Invalid index: {index}. Expected -1 or >= 0.")
+         else:
+             self._protocols.insert(index, protocol)
+
+     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+         """Save the object using the first registered protocol that supports its type."""
+         last_protocol_error = None
+         for protocol in self._protocols:
+             try:
+                 if self._is_supported_type(obj, protocol):
+                     logger.debug(f"Dumping object of type {type(obj)} with protocol {protocol}")
+                     return protocol.save(obj, dest_dir, session)
+             except Exception as e:
+                 logger.warning(f"Error dumping object {obj} with protocol {protocol}: {repr(e)}")
+                 last_protocol_error = (protocol.protocol_info, e)
+         last_error_str = (
+             f", most recent error ({last_protocol_error[0]}): {repr(last_protocol_error[1])}"
+             if last_protocol_error
+             else ""
+         )
+         raise ProtocolNotFoundError(
+             f"No suitable protocol found for type {type(obj).__name__}"
+             f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)}){last_error_str}"
+         )
+
+     def load(
+         self,
+         payload_info: ProtocolInfo,
+         session: Optional[snowpark.Session] = None,
+         path_transform: Optional[Callable[[str], str]] = None,
+     ) -> Any:
+         """Load the object using the registered protocol that matches the payload info."""
+         last_error = None
+         for protocol in self._protocols:
+             if protocol.protocol_info.name == payload_info.name:
+                 try:
+                     return protocol.load(payload_info, session, path_transform)
+                 except Exception as e:
+                     logger.warning(f"Error loading object with protocol {protocol}: {repr(e)}")
+                     last_error = e
+         if last_error:
+             raise last_error
+         raise ProtocolNotFoundError(
+             f"No protocol matching {payload_info} available"
+             f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)})"
+             ", possibly due to snowflake-ml-python package version mismatch"
+         )
+
+     def _is_supported_type(self, obj: Any, protocol: SerializationProtocol) -> bool:
+         if protocol.supported_types is None:
+             return True  # None means all types are supported
+         elif isinstance(protocol.supported_types, (type, tuple)):
+             return isinstance(obj, protocol.supported_types)
+         elif callable(protocol.supported_types):
+             return protocol.supported_types(obj) is True
+         raise ValueError(f"Invalid supported types: {protocol.supported_types} for protocol {protocol}")
snowflake/ml/jobs/_interop/results.py
@@ -0,0 +1,51 @@
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+
+ @dataclass(frozen=True)
+ class ExecutionResult:
+     """
+     The result of a job execution.
+
+     Attributes:
+         success: Whether the execution was successful.
+         value: The return value on success, or the exception on failure.
+     """
+
+     success: bool
+     value: Any
+
+     def get_value(self, wrap_exceptions: bool = True) -> Any:
+         if not self.success:
+             assert isinstance(self.value, BaseException), "Unexpected non-exception value for failed result"
+             self._raise_exception(self.value, wrap_exceptions)
+         return self.value
+
+     def _raise_exception(self, exception: BaseException, wrap_exceptions: bool) -> None:
+         if wrap_exceptions:
+             raise RuntimeError(f"Job execution failed with error: {exception!r}") from exception
+         else:
+             raise exception
+
+
+ @dataclass(frozen=True)
+ class LoadedExecutionResult(ExecutionResult):
+     """
+     The result of a job execution that has been loaded from a file.
+     """
+
+     load_error: Optional[Exception] = None
+     result_metadata: Optional[dict[str, Any]] = None
+
+     def get_value(self, wrap_exceptions: bool = True) -> Any:
+         if not self.success:
+             # Raise the original exception if available; otherwise synthesize one,
+             # chaining the load error as its cause
+             ex = self.value
+             if not isinstance(ex, BaseException):
+                 ex = RuntimeError(f"Unknown error {ex or ''}")
+                 ex.__cause__ = self.load_error
+             self._raise_exception(ex, wrap_exceptions)
+         else:
+             if self.load_error:
+                 raise ValueError("Job execution succeeded but result retrieval failed") from self.load_error
+             return self.value
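
The two classes split the failure semantics explicitly: a plain ExecutionResult either returns the value or raises, while LoadedExecutionResult also covers the case where the job itself succeeded but its result payload could not be deserialized. A small illustration of the intended behavior, using hypothetical in-memory values rather than values loaded from a stage:

from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult

# Success: get_value simply returns the stored value
assert ExecutionResult(success=True, value=42).get_value() == 42

# Failure: the stored exception is re-raised, wrapped in a RuntimeError by default
failed = ExecutionResult(success=False, value=ValueError("boom"))
try:
    failed.get_value()
except RuntimeError as e:
    assert isinstance(e.__cause__, ValueError)

# Unwrapped: the original exception type surfaces directly
try:
    failed.get_value(wrap_exceptions=False)
except ValueError:
    pass

# Job succeeded, but the result payload could not be loaded
partial = LoadedExecutionResult(success=True, value=None, load_error=TypeError("unsupported"))
try:
    partial.get_value()
except ValueError as e:  # "Job execution succeeded but result retrieval failed"
    assert isinstance(e.__cause__, TypeError)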
snowflake/ml/jobs/_interop/utils.py
@@ -0,0 +1,144 @@
+ import logging
+ import os
+ import traceback
+ from pathlib import PurePath
+ from typing import Any, Callable, Optional
+
+ import pydantic
+
+ from snowflake import snowpark
+ from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
+ from snowflake.ml.jobs._interop.dto_schema import (
+     ExceptionMetadata,
+     ResultDTO,
+     ResultMetadata,
+ )
+ from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
+ from snowflake.snowpark import exceptions as sp_exceptions
+
+ DEFAULT_CODEC = data_utils.JsonDtoCodec
+ DEFAULT_PROTOCOL = protocols.AutoProtocol()
+ # Each registration inserts at index 0, so later registrations take precedence
+ # and CloudPickleProtocol remains the last-resort fallback for arbitrary objects.
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.CloudPickleProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
+ DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
+     """
+     Save the result to a file, serializing the value with the best available protocol.
+     """
+     result_dto = ResultDTO(
+         success=result.success,
+         value=result.value,
+     )
+
+     try:
+         # Try to encode the result directly
+         payload = DEFAULT_CODEC.encode(result_dto)
+     except TypeError:
+         result_dto.value = None  # Remove raw value to avoid serialization error
+         result_dto.metadata = _get_metadata(result.value)  # Add metadata for client fallback on protocol mismatch
+         try:
+             path_dir = PurePath(path).parent.as_posix()
+             protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
+             result_dto.protocol = protocol_info
+         except Exception as e:
+             logger.warning(f"Error dumping result value: {repr(e)}")
+             result_dto.serialize_error = repr(e)
+
+         # Encode the modified result DTO
+         payload = DEFAULT_CODEC.encode(result_dto)
+
+     with data_utils.open_stream(path, "wb", session=session) as stream:
+         stream.write(payload)
+
+
+ def load_result(
+     path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
+ ) -> ExecutionResult:
+     """Load the result from a file on a Snowflake stage."""
+     try:
+         with data_utils.open_stream(path, "r", session=session) as stream:
+             # Load the DTO as a dict for easy fallback to legacy loading if necessary
+             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+     except UnicodeDecodeError:
+         # Path may be a legacy result file (cloudpickle)
+         # TODO: Re-use the stream
+         assert session is not None
+         return legacy.load_legacy_result(session, path)
+
+     try:
+         dto = ResultDTO.model_validate(dto_dict)
+     except pydantic.ValidationError as e:
+         if "success" in dto_dict:
+             # Legacy result JSON: the pickled payload lives alongside it in a .pkl file
+             assert session is not None
+             if path.endswith(".json"):
+                 path = os.path.splitext(path)[0] + ".pkl"
+             return legacy.load_legacy_result(session, path, result_json=dto_dict)
+         raise ValueError("Invalid result schema") from e
+
+     # Try loading data from file using the protocol info
+     result_value = None
+     data_load_error = None
+     if dto.protocol is not None:
+         try:
+             logger.debug(f"Loading result value with protocol {dto.protocol}")
+             result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
+         except sp_exceptions.SnowparkSQLException:
+             raise  # Data retrieval errors should be bubbled up
+         except Exception as e:
+             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+             data_load_error = e
+
+     # Wrap serialize_error in a TypeError
+     if dto.serialize_error:
+         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
+         if data_load_error:
+             data_load_error.__context__ = serialize_error
+         else:
+             data_load_error = serialize_error
+
+     # Prepare to assemble the final result
+     result_value = result_value if result_value is not None else dto.value
+     if not dto.success and result_value is None:
+         # Try to reconstruct the exception from metadata if available
+         if isinstance(dto.metadata, ExceptionMetadata):
+             logger.debug(f"Reconstructing exception from metadata {dto.metadata}")
+             result_value = exception_utils.build_exception(
+                 type_str=dto.metadata.type,
+                 message=dto.metadata.message,
+                 traceback=dto.metadata.traceback,
+                 original_repr=dto.metadata.repr,
+             )
+
+         # Generate a generic error if we still don't have a value,
+         # attaching the data load error if any
+         if result_value is None:
+             result_value = exception_utils.RemoteError("Unknown remote error")
+             result_value.__cause__ = data_load_error
+
+     return LoadedExecutionResult(
+         success=dto.success,
+         value=result_value,
+         load_error=data_load_error,
+     )
+
+
+ def _get_metadata(value: Any) -> ResultMetadata:
+     type_name = f"{type(value).__module__}.{type(value).__name__}"
+     if isinstance(value, BaseException):
+         return ExceptionMetadata(
+             type=type_name,
+             repr=repr(value),
+             message=str(value),
+             traceback="".join(traceback.format_tb(value.__traceback__)),
+         )
+     return ResultMetadata(
+         type=type_name,
+         repr=repr(value),
+     )
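
End to end, save_result and load_result form a round trip: the DTO is JSON-encoded when possible; otherwise the value is offloaded through DEFAULT_PROTOCOL and only its metadata and protocol info are kept inline. A sketch of that round trip, under the assumption that data_utils.open_stream falls back to local file I/O when no session is supplied and that the target directory exists; these are internal APIs and may change:

import pandas as pd

from snowflake.ml.jobs._interop import utils as interop_utils
from snowflake.ml.jobs._interop.results import ExecutionResult

# A DataFrame is not JSON-encodable, so save_result should fall back to
# PandasDataFrameProtocol and write a parquet file next to the result DTO.
df = pd.DataFrame({"a": [1, 2, 3]})
interop_utils.save_result(ExecutionResult(success=True, value=df), "/tmp/mljob/result.json")

# load_result reads the DTO, sees the recorded protocol info (name="pandas"),
# and rehydrates the DataFrame from the parquet payload.
loaded = interop_utils.load_result("/tmp/mljob/result.json")
print(loaded.get_value())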