snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/utils/mixins.py +26 -1
  2. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  3. snowflake/ml/data/data_connector.py +2 -2
  4. snowflake/ml/data/data_ingestor.py +2 -1
  5. snowflake/ml/experiment/_experiment_info.py +3 -3
  6. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  7. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  8. snowflake/ml/jobs/_interop/protocols.py +124 -7
  9. snowflake/ml/jobs/_interop/utils.py +92 -33
  10. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  11. snowflake/ml/jobs/_utils/constants.py +4 -0
  12. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  13. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  14. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  16. snowflake/ml/jobs/decorators.py +17 -22
  17. snowflake/ml/jobs/job.py +25 -10
  18. snowflake/ml/jobs/job_definition.py +100 -8
  19. snowflake/ml/model/_client/model/model_version_impl.py +25 -14
  20. snowflake/ml/model/_client/ops/service_ops.py +6 -6
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  23. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  24. snowflake/ml/model/openai_signatures.py +154 -0
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
  28. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
  29. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  30. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  31. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  32. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ _SESSION_ACCOUNT_KEY = "session$account"
  _SESSION_ROLE_KEY = "session$role"
  _SESSION_DATABASE_KEY = "session$database"
  _SESSION_SCHEMA_KEY = "session$schema"
+ _SESSION_STATE_ATTR = "_session_state"


  def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
@@ -61,7 +62,7 @@ class SerializableSessionMixin:
  else:
  self.__dict__.update(state)

- self._set_session(session_state)
+ setattr(self, _SESSION_STATE_ATTR, session_state)

  def _set_session(self, session_state: _SessionState) -> None:

@@ -86,3 +87,27 @@ class SerializableSessionMixin:
  ),
  ),
  )
+
+ @property
+ def session(self) -> Optional[snowpark_session.Session]:
+ if _SESSION_KEY not in self.__dict__:
+ session_state = getattr(self, _SESSION_STATE_ATTR, None)
+ if session_state is not None:
+ self._set_session(session_state)
+ return self.__dict__.get(_SESSION_KEY)
+
+ @session.setter
+ def session(self, value: Optional[snowpark_session.Session]) -> None:
+ self.__dict__[_SESSION_KEY] = value
+
+ # __getattr__ is only called when an attribute is NOT found through normal lookup:
+ # 1. Data descriptors (like @property with setter) from the class hierarchy
+ # 2. Instance __dict__ (e.g., self.x = 10)
+ # 3. Non-data descriptors (methods, @property without setter) from the class hierarchy
+ # __getattr__ is only called if steps 1-3 all fail
+ def __getattr__(self, name: str) -> Any:
+ if name == _SESSION_KEY:
+ return self.session
+ if hasattr(super(), "__getattr__"):
+ return super().__getattr__(name) # type: ignore[misc]
+ raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
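The net effect of this hunk is that deserialization no longer reconnects to Snowflake eagerly: `__setstate__` only stashes the captured session state, and the `session` property rebuilds the session on first access. A minimal sketch of the intended behavior, assuming a `SerializableSessionMixin` subclass and an already-established Snowpark session (the class and variable names below are illustrative, not part of the package):

    import pickle

    from snowflake.ml._internal.utils import mixins


    class StageReader(mixins.SerializableSessionMixin):
        """Illustrative object that keeps a Snowpark session for later use."""

        def __init__(self, session):
            self.session = session  # routed through the new `session` property setter


    # reader = StageReader(my_snowpark_session)  # my_snowpark_session is assumed to exist
    # blob = pickle.dumps(reader)                # __getstate__ (not shown in this hunk) presumably captures session identifiers
    # restored = pickle.loads(blob)              # __setstate__ now only records _session_state
    # restored.session                           # first access calls _set_session() and reconnects lazily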
@@ -73,15 +73,19 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
  self._schema: Optional[pa.Schema] = None

  @classmethod
- def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
+ def from_sources(
+ cls, session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
+ ) -> "ArrowIngestor":
  if session is None:
  raise ValueError("Session is required")
+ # Skip kwargs until they are needed, to avoid impacting other workflows.
  return cls(session, sources)

  @classmethod
  def from_ray_dataset(
  cls,
  ray_ds: "ray.data.Dataset",
+ **kwargs: Any,
  ) -> "ArrowIngestor":
  raise NotImplementedError

@@ -94,7 +94,7 @@ class DataConnector:
  **kwargs: Any,
  ) -> DataConnectorType:
  ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
- ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+ ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds, **kwargs)
  return cls(ray_ingestor, **kwargs)

  @classmethod
@@ -111,7 +111,7 @@ class DataConnector:
  **kwargs: Any,
  ) -> DataConnectorType:
  ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
- ingestor = ingestor_class.from_sources(session, sources)
+ ingestor = ingestor_class.from_sources(session, sources, **kwargs)
  return cls(ingestor, **kwargs)

  @property
@@ -16,7 +16,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
  class DataIngestor(Protocol):
  @classmethod
  def from_sources(
- cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
+ cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
  ) -> DataIngestorType:
  raise NotImplementedError

@@ -24,6 +24,7 @@ class DataIngestor(Protocol):
  def from_ray_dataset(
  cls: type[DataIngestorType],
  ray_ds: "ray.data.Dataset",
+ **kwargs: Any,
  ) -> DataIngestorType:
  raise NotImplementedError

@@ -3,7 +3,7 @@ import functools
  import types
  from typing import Callable, Optional

- from snowflake.ml import model
+ from snowflake.ml.model._client.model import model_version_impl
  from snowflake.ml.registry._manager import model_manager


@@ -23,7 +23,7 @@ class ExperimentInfoPatcher:
  """

  # Store original method at class definition time to avoid recursive patching
- _original_log_model: Callable[..., model.ModelVersion] = model_manager.ModelManager.log_model
+ _original_log_model: Callable[..., model_version_impl.ModelVersion] = model_manager.ModelManager.log_model

  # Stack of active experiment_info contexts for nested experiment support
  _experiment_info_stack: list[ExperimentInfo] = []
@@ -36,7 +36,7 @@ class ExperimentInfoPatcher:
  if not ExperimentInfoPatcher._experiment_info_stack:

  @functools.wraps(ExperimentInfoPatcher._original_log_model)
- def patched(*args, **kwargs) -> model.ModelVersion: # type: ignore[no-untyped-def]
+ def patched(*args, **kwargs) -> model_version_impl.ModelVersion: # type: ignore[no-untyped-def]
  # Use the most recent (top of stack) experiment_info for nested contexts
  current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
  return ExperimentInfoPatcher._original_log_model(
@@ -31,7 +31,7 @@ class StageFileWriter(io.IOBase):
  # Only upload if buffer has content and no exception occurred
  if write_contents and self._buffer.tell() > 0:
  self._buffer.seek(0)
- self._session.file.put_stream(self._buffer, self._path)
+ self._session.file.put_stream(self._buffer, self._path, auto_compress=False)
  self._buffer.close()
  self._closed = True

@@ -84,15 +84,15 @@ class DtoCodec(Protocol):

  @overload
  @staticmethod
- def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
  ...

  @staticmethod
- def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
  pass

  @staticmethod
- def encode(dto: dto_schema.ResultDTO) -> bytes:
+ def encode(dto: dto_schema.PayloadDTO) -> bytes:
  pass


@@ -104,18 +104,18 @@ class JsonDtoCodec(DtoCodec):

  @overload
  @staticmethod
- def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+ def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
  ...

  @staticmethod
- def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+ def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
  data = cast(dict[str, Any], json.load(stream))
  if as_dict:
  return data
- return dto_schema.ResultDTO.model_validate(data)
+ return dto_schema.ResultDTOAdapter.validate_python(data)

  @staticmethod
- def encode(dto: dto_schema.ResultDTO) -> bytes:
+ def encode(dto: dto_schema.PayloadDTO) -> bytes:
  # Temporarily extract the value to avoid accidentally applying model_dump() on it
  result_value = dto.value
  dto.value = None # Clear value to avoid serializing it in the model_dump
@@ -1,7 +1,7 @@
- from typing import Any, Optional, Union
+ from typing import Any, Literal, Optional, Union

- from pydantic import BaseModel, model_validator
- from typing_extensions import NotRequired, TypedDict
+ from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, model_validator
+ from typing_extensions import Annotated, NotRequired, TypedDict


  class BinaryManifest(TypedDict):
@@ -67,22 +67,47 @@ class ExceptionMetadata(ResultMetadata):
  traceback: str


- class ResultDTO(BaseModel):
+ class PayloadDTO(BaseModel):
+ """
+ Base class for serializable payloads.
+
+ Args:
+ kind: Discriminator field for DTO type dispatch.
+ value: The payload value (if JSON-serializable).
+ protocol: The protocol used to serialize the payload (if not JSON-serializable).
+ """
+
+ kind: Literal["base"] = "base"
+ value: Optional[Any] = None
+ protocol: Optional[ProtocolInfo] = None
+ serialize_error: Optional[str] = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_fields(cls, data: Any) -> Any:
+ """Ensure at least one of value or protocol keys is specified."""
+ if cls is PayloadDTO and isinstance(data, dict):
+ required_fields = {"value", "protocol"}
+ if not any(field in data for field in required_fields):
+ raise ValueError("At least one of 'value' or 'protocol' must be specified")
+ return data
+
+
+ class ResultDTO(PayloadDTO):
  """
  A JSON representation of an execution result.

  Args:
+ kind: Discriminator field for DTO type dispatch.
  success: Whether the execution was successful.
  value: The value of the execution or the exception if the execution failed.
  protocol: The protocol used to serialize the result.
  metadata: The metadata of the result.
  """

+ kind: Literal["result"] = "result" # type: ignore[assignment]
  success: bool
- value: Optional[Any] = None
- protocol: Optional[ProtocolInfo] = None
  metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
- serialize_error: Optional[str] = None

  @model_validator(mode="before")
  @classmethod
@@ -93,3 +118,23 @@ class ResultDTO(BaseModel):
  if not any(field in data for field in required_fields):
  raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
  return data
+
+
+ def _get_dto_kind(data: Any) -> str:
+ """Extract the 'kind' discriminator from input, defaulting to 'result' for backward compatibility."""
+ if isinstance(data, dict):
+ kind = data.get("kind", "result")
+ else:
+ kind = getattr(data, "kind", "result")
+ return str(kind)
+
+
+ AnyResultDTO = Annotated[
+ Union[
+ Annotated[ResultDTO, Tag("result")],
+ Annotated[PayloadDTO, Tag("base")],
+ ],
+ Discriminator(_get_dto_kind),
+ ]
+
+ ResultDTOAdapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
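For reference, the tagged union above dispatches on the `kind` field and treats payloads without one as results, which keeps pre-1.25 result files loadable. A minimal sketch of that dispatch, assuming pydantic v2 `TypeAdapter` semantics and the definitions in this hunk (the literal payloads are illustrative):

    from snowflake.ml.jobs._interop.dto_schema import (
        PayloadDTO,
        ResultDTO,
        ResultDTOAdapter,
    )

    # Older payloads carry no "kind" key; _get_dto_kind defaults them to "result".
    legacy_result = ResultDTOAdapter.validate_python({"success": True, "value": 42})
    assert isinstance(legacy_result, ResultDTO)

    # New argument payloads tag themselves as "base" and validate as the base PayloadDTO.
    arg_payload = ResultDTOAdapter.validate_python({"kind": "base", "value": [1, 2, 3]})
    assert type(arg_payload) is PayloadDTO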
@@ -17,6 +17,8 @@ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]

  logger = logging.getLogger(__name__)

+ SESSION_KEY_PREFIX = "session@"
+

  class SerializationError(TypeError):
  """Exception raised when a serialization protocol fails."""
@@ -136,9 +138,10 @@ class CloudPickleProtocol(SerializationProtocol):

  def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
  """Save the object to the destination directory."""
+ replaced_obj = self._pack_obj(obj)
  result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
  with data_utils.open_stream(result_path, "wb", session=session) as f:
- self._backend.dump(obj, f)
+ self._backend.dump(replaced_obj, f)
  manifest: BinaryManifest = {"path": result_path}
  return self.protocol_info.with_manifest(manifest)

@@ -157,12 +160,15 @@ class CloudPickleProtocol(SerializationProtocol):
  payload_manifest = cast(BinaryManifest, payload_info.manifest)
  try:
  if payload_bytes := payload_manifest.get("bytes"):
- return self._backend.loads(payload_bytes)
- if payload_b64 := payload_manifest.get("base64"):
- return self._backend.loads(base64.b64decode(payload_b64))
- result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
- with data_utils.open_stream(result_path, "rb", session=session) as f:
- return self._backend.load(f)
+ result = self._backend.loads(payload_bytes)
+ elif payload_b64 := payload_manifest.get("base64"):
+ result = self._backend.loads(base64.b64decode(payload_b64))
+ else:
+ result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+ with data_utils.open_stream(result_path, "rb", session=session) as f:
+ result = self._backend.load(f)
+
+ return self._unpack_obj(result, session=session)
  except (
  pickle.UnpicklingError,
  TypeError,
@@ -173,6 +179,117 @@ class CloudPickleProtocol(SerializationProtocol):
  raise error from pickle_error
  raise

+ def _pack_obj(self, obj: Any) -> Any:
+ """Pack objects into JSON-safe dicts using reserved marker keys.
+
+ Markers:
+ - "type@": container type for list/tuple (list or tuple)
+ - "#<i>": positional element for list/tuple at index i
+ - "session@": placeholder for snowpark.Session values
+ - "session@#<i>" for list/tuple entries
+ - "session@<key>" for dict entries
+ - {"session@": None} for a bare Session object
+
+ Example:
+ obj = {"x": [1, session], "s": session}
+ packed = {
+ "x": {"type@": list, "#0": 1, "session@#1": None},
+ "session@s": None,
+ }
+ _unpack_obj(packed, session) == obj
+
+ Args:
+ obj: Object to pack into JSON-safe marker dictionaries.
+
+ Returns:
+ Packed representation with markers for session references.
+ """
+ arguments: dict[str, Any] = {}
+ if isinstance(obj, tuple) or isinstance(obj, list):
+ arguments = {"type@": type(obj)}
+ for i, arg in enumerate(obj):
+ if isinstance(arg, snowpark.Session):
+ arguments[f"{SESSION_KEY_PREFIX}#{i}"] = None
+ else:
+ arguments[f"#{i}"] = self._pack_obj(arg)
+ return arguments
+ elif isinstance(obj, dict):
+ for k, v in obj.items():
+ if isinstance(v, snowpark.Session):
+ arguments[f"{SESSION_KEY_PREFIX}{k}"] = None
+ else:
+ arguments[k] = self._pack_obj(v)
+ return arguments
+ elif isinstance(obj, snowpark.Session):
+ # Box session into a dict marker so we can distinguish it from other plain objects.
+ arguments[f"{SESSION_KEY_PREFIX}"] = None
+ return arguments
+ else:
+ return obj
+
+ def _unpack_obj(self, obj: Any, session: Optional[snowpark.Session] = None) -> Any:
+ """Unpack dict markers back into containers and Session references.
+
+ Markers:
+ - "type@": container type for list/tuple (list or tuple)
+ - "#<i>": positional element for list/tuple at index i
+ - "session@": placeholder for snowpark.Session values
+ - "session@#<i>" for list/tuple entries
+ - "session@<key>" for dict entries
+ - {"session@": None} for a bare Session object
+
+ Example:
+ packed = {
+ "x": {"type@": list, "#0": 1, "session@#1": None},
+ "session@s": None,
+ }
+ obj = _unpack_obj(packed, session)
+ # obj == {"x": [1, session], "s": session}
+
+ Args:
+ obj: Packed object with marker dictionaries.
+ session: Session to inject for session markers.
+
+ Returns:
+ Unpacked object with session references restored.
+ """
+ if not isinstance(obj, dict):
+ return obj
+ elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
+ return session
+ else:
+ type = obj.get("type@", None)
+ # If type is None, we are unpacking a dict
+ if type is None:
+ result_dict = {}
+ for k, v in obj.items():
+ if k.startswith(SESSION_KEY_PREFIX):
+ result_key = k[len(SESSION_KEY_PREFIX) :]
+ result_dict[result_key] = session
+ else:
+ result_dict[k] = self._unpack_obj(v, session)
+ return result_dict
+ # If type is not None, we are unpacking a tuple or list
+ else:
+ indexes = []
+ for k, _ in obj.items():
+ if "#" in k:
+ indexes.append(int(k.split("#")[-1]))
+
+ if not indexes:
+ return tuple() if type is tuple else []
+ result_list: list[Any] = [None] * (max(indexes) + 1)
+
+ for k, v in obj.items():
+ if k == "type@":
+ continue
+ idx = int(k.split("#")[-1])
+ if k.startswith(SESSION_KEY_PREFIX):
+ result_list[idx] = session
+ else:
+ result_list[idx] = self._unpack_obj(v, session)
+ return tuple(result_list) if type is tuple else result_list
+

  class ArrowTableProtocol(SerializationProtocol):
  """
@@ -1,3 +1,4 @@
+ import io
  import logging
  import os
  import traceback
@@ -10,7 +11,9 @@ from snowflake import snowpark
  from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
  from snowflake.ml.jobs._interop.dto_schema import (
  ExceptionMetadata,
+ PayloadDTO,
  ResultDTO,
+ ResultDTOAdapter,
  ResultMetadata,
  )
  from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
@@ -23,79 +26,137 @@ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
  DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
  DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)

+ # Constants for argument encoding
+ _MAX_INLINE_SIZE = 1024 * 1024 # 1MB - https://docs.snowflake.com/en/user-guide/query-size-limits

  logger = logging.getLogger(__name__)


- def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
+ def save(
+ value: Any,
+ path: str,
+ session: Optional[snowpark.Session] = None,
+ max_inline_size: int = 0,
+ ) -> Optional[bytes]:
  """
- Save the result to a file.
+ Serialize a value. Returns inline bytes if small enough, else writes to file.
+
+ Args:
+ value: The value to serialize. If ExecutionResult, creates ResultDTO with success flag.
+ path: Full file path for writing the DTO (if needed). Protocol data saved to path's parent.
+ session: Snowpark session for stage operations.
+ max_inline_size: Max bytes for inline return. 0 = always write to file.
+
+ Returns:
+ Encoded bytes if <= max_inline_size, else None (written to file).
+
+ Raises:
+ Exception: If session validation fails during serialization.
  """
- result_dto = ResultDTO(
- success=result.success,
- value=result.value,
- )
+ if isinstance(value, ExecutionResult):
+ dto: PayloadDTO = ResultDTO(success=value.success, value=value.value)
+ raw_value = value.value
+ else:
+ dto = PayloadDTO(value=value)
+ raw_value = value

  try:
- # Try to encode result directly
- payload = DEFAULT_CODEC.encode(result_dto)
+ payload = DEFAULT_CODEC.encode(dto)
  except TypeError:
- result_dto.value = None # Remove raw value to avoid serialization error
- result_dto.metadata = _get_metadata(result.value) # Add metadata for client fallback on protocol mismatch
+ dto.value = None # Remove raw value to avoid serialization error
+ if isinstance(dto, ResultDTO):
+ # Metadata enables client fallback display when result can't be deserialized (protocol mismatch).
+ dto.metadata = _get_metadata(raw_value)
  try:
  path_dir = PurePath(path).parent.as_posix()
- protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
- result_dto.protocol = protocol_info
+ protocol_info = DEFAULT_PROTOCOL.save(raw_value, path_dir, session=session)
+ dto.protocol = protocol_info

  except Exception as e:
  logger.warning(f"Error dumping result value: {repr(e)}")
- result_dto.serialize_error = repr(e)
-
- # Encode the modified result DTO
- payload = DEFAULT_CODEC.encode(result_dto)
+ # We handle serialization failures differently based on the DTO type:
+ # 1. Job Results (ResultDTO): Allow a "soft-fail."
+ # Since the job has already executed, we return the serialization error
+ # to the client so they can debug the output or update their protocol version.
+ # 2. Input Arguments: Trigger a "hard-fail."
+ # If arguments cannot be saved, the job script cannot run. We raise
+ # an immediate exception to prevent execution with invalid state.
+ if not isinstance(dto, ResultDTO):
+ raise
+ dto.serialize_error = repr(e)
+
+ # Encode the modified DTO
+ payload = DEFAULT_CODEC.encode(dto)
+
+ if not isinstance(dto, ResultDTO) and len(payload) <= max_inline_size:
+ return payload

  with data_utils.open_stream(path, "wb", session=session) as stream:
  stream.write(payload)
+ return None
+
+
+ save_result = save # Backwards compatibility


- def load_result(
- path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
- ) -> ExecutionResult:
- """Load the result from a file on a Snowflake stage."""
+ def load(
+ path_or_data: str,
+ session: Optional[snowpark.Session] = None,
+ path_transform: Optional[Callable[[str], str]] = None,
+ ) -> Any:
+ """Load data from a file path or inline string."""
+
  try:
- with data_utils.open_stream(path, "r", session=session) as stream:
+ with data_utils.open_stream(path_or_data, "r", session=session) as stream:
  # Load the DTO as a dict for easy fallback to legacy loading if necessary
  dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+ # The exception could be OSError or BlockingIOError (e.g. the file name is too long)
+ except OSError as e:
+ # path_or_data might be inline data
+ try:
+ dto_dict = DEFAULT_CODEC.decode(io.StringIO(path_or_data), as_dict=True)
+ except Exception:
+ raise e
  except UnicodeDecodeError:
  # Path may be a legacy result file (cloudpickle)
- # TODO: Re-use the stream
  assert session is not None
- return legacy.load_legacy_result(session, path)
+ return legacy.load_legacy_result(session, path_or_data)

  try:
- dto = ResultDTO.model_validate(dto_dict)
+ dto = ResultDTOAdapter.validate_python(dto_dict)
  except pydantic.ValidationError as e:
  if "success" in dto_dict:
  assert session is not None
- if path.endswith(".json"):
- path = os.path.splitext(path)[0] + ".pkl"
- return legacy.load_legacy_result(session, path, result_json=dto_dict)
+ if path_or_data.endswith(".json"):
+ path_or_data = os.path.splitext(path_or_data)[0] + ".pkl"
+ return legacy.load_legacy_result(session, path_or_data, result_json=dto_dict)
  raise ValueError("Invalid result schema") from e

  # Try loading data from file using the protocol info
- result_value = None
+ payload_value = None
  data_load_error = None
  if dto.protocol is not None:
  try:
  logger.debug(f"Loading result value with protocol {dto.protocol}")
- result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
+ payload_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
  except sp_exceptions.SnowparkSQLException:
  raise # Data retrieval errors should be bubbled up
  except Exception as e:
  logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+ # Error handling strategy depends on the DTO type:
+ # 1. ResultDTO: Soft-fail. The job has already finished.
+ # We package the load error into the result so the client can
+ # debug or adjust their protocol version to retrieve the output.
+ # 2. PayloadDTO: Raise a hard error. If arguments cannot be
+ # loaded, the job cannot run. We abort early to prevent execution.
+ if not isinstance(dto, ResultDTO):
+ raise
  data_load_error = e

- # Wrap serialize_error in a TypeError
+ # Prepare to assemble the final result
+ if not isinstance(dto, ResultDTO):
+ return payload_value
+
  if dto.serialize_error:
  serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
  if data_load_error:
@@ -103,8 +164,7 @@ def load_result(
  else:
  data_load_error = serialize_error

- # Prepare to assemble the final result
- result_value = result_value if result_value is not None else dto.value
+ result_value = payload_value if payload_value is not None else dto.value
  if not dto.success and result_value is None:
  # Try to reconstruct exception from metadata if available
  if isinstance(dto.metadata, ExceptionMetadata):
@@ -115,7 +175,6 @@ def load_result(
  traceback=dto.metadata.traceback,
  original_repr=dto.metadata.repr,
  )
-
  # Generate a generic error if we still don't have a value,
  # attaching the data load error if any
  if result_value is None:
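Taken together, the new entry points let the same codec handle job arguments and job results: plain values become a PayloadDTO that may be returned inline, while ExecutionResult values always become a ResultDTO written to the given path. A rough usage sketch under those assumptions; the stage path, argument values, and `session` variable are illustrative, not taken from the package:

    from snowflake.ml.jobs._interop import utils

    # Job arguments: small encoded payloads are returned as bytes instead of being
    # written out (the 1MB ceiling here mirrors _MAX_INLINE_SIZE above).
    inline = utils.save(
        {"epochs": 3, "learning_rate": 0.01},  # illustrative arguments
        "@my_stage/my_job/mljob_args.json",    # illustrative stage path
        session=session,                       # assumed active Snowpark session
        max_inline_size=1024 * 1024,
    )
    if inline is None:
        print("argument payload was too large to inline; written to the stage path")

    # Job results: an ExecutionResult is always written to the path as a ResultDTO;
    # serialization failures are recorded in serialize_error instead of raising (soft-fail).
    # utils.save(execution_result, "@my_stage/my_job/mljob_result.json", session=session)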
@@ -0,0 +1,7 @@
+ from enum import Enum
+
+
+ class ArgProtocol(Enum):
+ NONE = 0
+ CLI = 1
+ PICKLE = 2
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
  DEFAULT_CONTAINER_NAME = "main"
  MEMORY_VOLUME_NAME = "dshm"
  STAGE_VOLUME_NAME = "stage-volume"
+ RESULT_VOLUME_NAME = "result-volume"
  DEFAULT_PYTHON_VERSION = "3.10"

  # Environment variables
@@ -109,3 +110,6 @@ CLOUD_INSTANCE_FAMILIES = {
  SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
  SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
  }
+
+ # Magic attributes
+ IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"