snowflake-ml-python 1.25.0__py3-none-any.whl → 1.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,6 @@ _SESSION_ACCOUNT_KEY = "session$account"
 _SESSION_ROLE_KEY = "session$role"
 _SESSION_DATABASE_KEY = "session$database"
 _SESSION_SCHEMA_KEY = "session$schema"
-_SESSION_STATE_ATTR = "_session_state"
 
 
 def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
@@ -62,7 +61,7 @@ class SerializableSessionMixin:
         else:
             self.__dict__.update(state)
 
-        setattr(self, _SESSION_STATE_ATTR, session_state)
+        self._set_session(session_state)
 
     def _set_session(self, session_state: _SessionState) -> None:
 
@@ -87,27 +86,3 @@ class SerializableSessionMixin:
             ),
         ),
     )
-
-    @property
-    def session(self) -> Optional[snowpark_session.Session]:
-        if _SESSION_KEY not in self.__dict__:
-            session_state = getattr(self, _SESSION_STATE_ATTR, None)
-            if session_state is not None:
-                self._set_session(session_state)
-        return self.__dict__.get(_SESSION_KEY)
-
-    @session.setter
-    def session(self, value: Optional[snowpark_session.Session]) -> None:
-        self.__dict__[_SESSION_KEY] = value
-
-    # __getattr__ is only called when an attribute is NOT found through normal lookup:
-    # 1. Data descriptors (like @property with setter) from the class hierarchy
-    # 2. Instance __dict__ (e.g., self.x = 10)
-    # 3. Non-data descriptors (methods, @property without setter) from the class hierarchy
-    # __getattr__ — only called if steps 1-3 all fail
-    def __getattr__(self, name: str) -> Any:
-        if name == _SESSION_KEY:
-            return self.session
-        if hasattr(super(), "__getattr__"):
-            return super().__getattr__(name)  # type: ignore[misc]
-        raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
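
Net effect of these three hunks: 1.25.0 restored the session lazily, on the first `obj.session` read, through the property and `__getattr__` fallback deleted above; 1.25.1 restores it eagerly by calling `self._set_session(session_state)` inside `__setstate__`, which makes the lazy-lookup machinery dead code. A rough sketch of the removed lazy pattern (plumbing simplified; `_rebuild_session` is a hypothetical stand-in for `_set_session`):

```python
class LazySessionSketch:
    """Illustrative only; not the shipped SerializableSessionMixin."""

    def __setstate__(self, state):
        self.__dict__.update(state)  # note: session NOT restored here

    def _rebuild_session(self):
        return object()  # placeholder for reconnecting a Snowpark session

    @property
    def session(self):
        # Data descriptor intercepts every read; rebuild on first access.
        if "session" not in self.__dict__:
            self.__dict__["session"] = self._rebuild_session()
        return self.__dict__["session"]

    @session.setter
    def session(self, value):
        self.__dict__["session"] = value
```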
@@ -31,7 +31,7 @@ class StageFileWriter(io.IOBase):
         # Only upload if buffer has content and no exception occurred
         if write_contents and self._buffer.tell() > 0:
            self._buffer.seek(0)
-            self._session.file.put_stream(self._buffer, self._path, auto_compress=False)
+            self._session.file.put_stream(self._buffer, self._path)
         self._buffer.close()
         self._closed = True
 
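One behavioral note: Snowpark's `Session.file.put_stream` defaults to `auto_compress=True`, so removing the explicit `auto_compress=False` likely means buffers flushed by `StageFileWriter` now land gzip-compressed on the stage. If a caller needs the old uncompressed layout, the flag can still be pinned explicitly (sketch; `session` and the stage path are assumptions):

```python
import io

buf = io.BytesIO(b"payload bytes")
# auto_compress is part of Snowpark's public put_stream API; pinning it False
# reproduces the pre-1.25.1 behavior of this writer.
session.file.put_stream(buf, "@my_stage/output/result.json", auto_compress=False)
```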
@@ -84,15 +84,15 @@ class DtoCodec(Protocol):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
         pass
 
     @staticmethod
-    def encode(dto: dto_schema.PayloadDTO) -> bytes:
+    def encode(dto: dto_schema.ResultDTO) -> bytes:
         pass
 
 
@@ -104,18 +104,18 @@ class JsonDtoCodec(DtoCodec):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
         data = cast(dict[str, Any], json.load(stream))
         if as_dict:
             return data
-        return dto_schema.ResultDTOAdapter.validate_python(data)
+        return dto_schema.ResultDTO.model_validate(data)
 
     @staticmethod
-    def encode(dto: dto_schema.PayloadDTO) -> bytes:
+    def encode(dto: dto_schema.ResultDTO) -> bytes:
         # Temporarily extract the value to avoid accidentally applying model_dump() on it
         result_value = dto.value
         dto.value = None  # Clear value to avoid serializing it in the model_dump
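
Net effect on the decode path: the discriminated-union indirection is gone and validation goes straight through pydantic v2's `ResultDTO.model_validate`. A self-contained sketch of both `as_dict` branches, with the model abbreviated to two fields:

```python
import io
import json
from typing import Any, Optional

from pydantic import BaseModel


class ResultDTO(BaseModel):  # abbreviated stand-in for the real model
    success: bool
    value: Optional[Any] = None


stream = io.StringIO(json.dumps({"success": True, "value": 42}))
data = json.load(stream)              # decode(..., as_dict=True) path
dto = ResultDTO.model_validate(data)  # decode(..., as_dict=False) path in 1.25.1
assert dto.value == 42
```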
@@ -1,7 +1,7 @@
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
-from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, model_validator
-from typing_extensions import Annotated, NotRequired, TypedDict
+from pydantic import BaseModel, model_validator
+from typing_extensions import NotRequired, TypedDict
 
 
 class BinaryManifest(TypedDict):
@@ -67,47 +67,22 @@ class ExceptionMetadata(ResultMetadata):
     traceback: str
 
 
-class PayloadDTO(BaseModel):
-    """
-    Base class for serializable payloads.
-
-    Args:
-        kind: Discriminator field for DTO type dispatch.
-        value: The payload value (if JSON-serializable).
-        protocol: The protocol used to serialize the payload (if not JSON-serializable).
-    """
-
-    kind: Literal["base"] = "base"
-    value: Optional[Any] = None
-    protocol: Optional[ProtocolInfo] = None
-    serialize_error: Optional[str] = None
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_fields(cls, data: Any) -> Any:
-        """Ensure at least one of value or protocol keys is specified."""
-        if cls is PayloadDTO and isinstance(data, dict):
-            required_fields = {"value", "protocol"}
-            if not any(field in data for field in required_fields):
-                raise ValueError("At least one of 'value' or 'protocol' must be specified")
-        return data
-
-
-class ResultDTO(PayloadDTO):
+class ResultDTO(BaseModel):
     """
     A JSON representation of an execution result.
 
     Args:
-        kind: Discriminator field for DTO type dispatch.
         success: Whether the execution was successful.
         value: The value of the execution or the exception if the execution failed.
         protocol: The protocol used to serialize the result.
         metadata: The metadata of the result.
     """
 
-    kind: Literal["result"] = "result"  # type: ignore[assignment]
     success: bool
+    value: Optional[Any] = None
+    protocol: Optional[ProtocolInfo] = None
     metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
+    serialize_error: Optional[str] = None
 
     @model_validator(mode="before")
     @classmethod
@@ -118,23 +93,3 @@ class ResultDTO(PayloadDTO):
             if not any(field in data for field in required_fields):
                 raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
         return data
-
-
-def _get_dto_kind(data: Any) -> str:
-    """Extract the 'kind' discriminator from input, defaulting to 'result' for backward compatibility."""
-    if isinstance(data, dict):
-        kind = data.get("kind", "result")
-    else:
-        kind = getattr(data, "kind", "result")
-    return str(kind)
-
-
-AnyResultDTO = Annotated[
-    Union[
-        Annotated[ResultDTO, Tag("result")],
-        Annotated[PayloadDTO, Tag("base")],
-    ],
-    Discriminator(_get_dto_kind),
-]
-
-ResultDTOAdapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
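
With the base/result split removed, all presence checking now lives on `ResultDTO` itself. A runnable sketch of the validator's semantics under pydantic v2 (field types simplified; `ProtocolInfo` and the metadata models replaced with plain dicts). Note the `mode="before"` check tests key *presence*, so `{"success": True, "value": None}` validates while `{"success": False}` does not:

```python
from typing import Any, Optional

from pydantic import BaseModel, ValidationError, model_validator


class ResultDTO(BaseModel):  # simplified stand-in for the diffed model
    success: bool
    value: Optional[Any] = None
    protocol: Optional[dict] = None
    metadata: Optional[dict] = None
    serialize_error: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_fields(cls, data: Any) -> Any:
        if isinstance(data, dict):
            if not any(f in data for f in ("value", "protocol", "metadata")):
                raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
        return data


ResultDTO.model_validate({"success": True, "value": None})  # ok: key is present
try:
    ResultDTO.model_validate({"success": False})            # rejected: no keys
except ValidationError:
    pass
```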
@@ -17,8 +17,6 @@ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
 
 logger = logging.getLogger(__name__)
 
-SESSION_KEY_PREFIX = "session@"
-
 
 class SerializationError(TypeError):
     """Exception raised when a serialization protocol fails."""
@@ -138,10 +136,9 @@ class CloudPickleProtocol(SerializationProtocol):
 
     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
         """Save the object to the destination directory."""
-        replaced_obj = self._pack_obj(obj)
         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
         with data_utils.open_stream(result_path, "wb", session=session) as f:
-            self._backend.dump(replaced_obj, f)
+            self._backend.dump(obj, f)
         manifest: BinaryManifest = {"path": result_path}
         return self.protocol_info.with_manifest(manifest)
 
@@ -160,15 +157,12 @@ class CloudPickleProtocol(SerializationProtocol):
         payload_manifest = cast(BinaryManifest, payload_info.manifest)
         try:
             if payload_bytes := payload_manifest.get("bytes"):
-                result = self._backend.loads(payload_bytes)
-            elif payload_b64 := payload_manifest.get("base64"):
-                result = self._backend.loads(base64.b64decode(payload_b64))
-            else:
-                result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
-                with data_utils.open_stream(result_path, "rb", session=session) as f:
-                    result = self._backend.load(f)
-
-            return self._unpack_obj(result, session=session)
+                return self._backend.loads(payload_bytes)
+            if payload_b64 := payload_manifest.get("base64"):
+                return self._backend.loads(base64.b64decode(payload_b64))
+            result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+            with data_utils.open_stream(result_path, "rb", session=session) as f:
+                return self._backend.load(f)
         except (
             pickle.UnpicklingError,
             TypeError,
@@ -179,117 +173,6 @@ class CloudPickleProtocol(SerializationProtocol):
             raise error from pickle_error
         raise
 
-    def _pack_obj(self, obj: Any) -> Any:
-        """Pack objects into JSON-safe dicts using reserved marker keys.
-
-        Markers:
-            - "type@": container type for list/tuple (list or tuple)
-            - "#<i>": positional element for list/tuple at index i
-            - "session@": placeholder for snowpark.Session values
-                - "session@#<i>" for list/tuple entries
-                - "session@<key>" for dict entries
-                - {"session@": None} for a bare Session object
-
-        Example:
-            obj = {"x": [1, session], "s": session}
-            packed = {
-                "x": {"type@": list, "#0": 1, "session@#1": None},
-                "session@s": None,
-            }
-            _unpack_obj(packed, session) == obj
-
-        Args:
-            obj: Object to pack into JSON-safe marker dictionaries.
-
-        Returns:
-            Packed representation with markers for session references.
-        """
-        arguments: dict[str, Any] = {}
-        if isinstance(obj, tuple) or isinstance(obj, list):
-            arguments = {"type@": type(obj)}
-            for i, arg in enumerate(obj):
-                if isinstance(arg, snowpark.Session):
-                    arguments[f"{SESSION_KEY_PREFIX}#{i}"] = None
-                else:
-                    arguments[f"#{i}"] = self._pack_obj(arg)
-            return arguments
-        elif isinstance(obj, dict):
-            for k, v in obj.items():
-                if isinstance(v, snowpark.Session):
-                    arguments[f"{SESSION_KEY_PREFIX}{k}"] = None
-                else:
-                    arguments[k] = self._pack_obj(v)
-            return arguments
-        elif isinstance(obj, snowpark.Session):
-            # Box session into a dict marker so we can distinguish it from other plain objects.
-            arguments[f"{SESSION_KEY_PREFIX}"] = None
-            return arguments
-        else:
-            return obj
-
-    def _unpack_obj(self, obj: Any, session: Optional[snowpark.Session] = None) -> Any:
-        """Unpack dict markers back into containers and Session references.
-
-        Markers:
-            - "type@": container type for list/tuple (list or tuple)
-            - "#<i>": positional element for list/tuple at index i
-            - "session@": placeholder for snowpark.Session values
-                - "session@#<i>" for list/tuple entries
-                - "session@<key>" for dict entries
-                - {"session@": None} for a bare Session object
-
-        Example:
-            packed = {
-                "x": {"type@": list, "#0": 1, "session@#1": None},
-                "session@s": None,
-            }
-            obj = _unpack_obj(packed, session)
-            # obj == {"x": [1, session], "s": session}
-
-        Args:
-            obj: Packed object with marker dictionaries.
-            session: Session to inject for session markers.
-
-        Returns:
-            Unpacked object with session references restored.
-        """
-        if not isinstance(obj, dict):
-            return obj
-        elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
-            return session
-        else:
-            type = obj.get("type@", None)
-            # If type is None, we are unpacking a dict
-            if type is None:
-                result_dict = {}
-                for k, v in obj.items():
-                    if k.startswith(SESSION_KEY_PREFIX):
-                        result_key = k[len(SESSION_KEY_PREFIX) :]
-                        result_dict[result_key] = session
-                    else:
-                        result_dict[k] = self._unpack_obj(v, session)
-                return result_dict
-            # If type is not None, we are unpacking a tuple or list
-            else:
-                indexes = []
-                for k, _ in obj.items():
-                    if "#" in k:
-                        indexes.append(int(k.split("#")[-1]))
-
-                if not indexes:
-                    return tuple() if type is tuple else []
-                result_list: list[Any] = [None] * (max(indexes) + 1)
-
-                for k, v in obj.items():
-                    if k == "type@":
-                        continue
-                    idx = int(k.split("#")[-1])
-                    if k.startswith(SESSION_KEY_PREFIX):
-                        result_list[idx] = session
-                    else:
-                        result_list[idx] = self._unpack_obj(v, session)
-                return tuple(result_list) if type is tuple else result_list
-
 
 class ArrowTableProtocol(SerializationProtocol):
     """
@@ -1,4 +1,3 @@
-import io
 import logging
 import os
 import traceback
@@ -11,9 +10,7 @@ from snowflake import snowpark
 from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
 from snowflake.ml.jobs._interop.dto_schema import (
     ExceptionMetadata,
-    PayloadDTO,
     ResultDTO,
-    ResultDTOAdapter,
     ResultMetadata,
 )
 from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
@@ -26,137 +23,79 @@ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
 
-# Constants for argument encoding
-_MAX_INLINE_SIZE = 1024 * 1024  # 1MB - https://docs.snowflake.com/en/user-guide/query-size-limits
 
 logger = logging.getLogger(__name__)
 
 
-def save(
-    value: Any,
-    path: str,
-    session: Optional[snowpark.Session] = None,
-    max_inline_size: int = 0,
-) -> Optional[bytes]:
+def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
     """
-    Serialize a value. Returns inline bytes if small enough, else writes to file.
-
-    Args:
-        value: The value to serialize. If ExecutionResult, creates ResultDTO with success flag.
-        path: Full file path for writing the DTO (if needed). Protocol data saved to path's parent.
-        session: Snowpark session for stage operations.
-        max_inline_size: Max bytes for inline return. 0 = always write to file.
-
-    Returns:
-        Encoded bytes if <= max_inline_size, else None (written to file).
-
-    Raises:
-        Exception: If session validation fails during serialization.
+    Save the result to a file.
     """
-    if isinstance(value, ExecutionResult):
-        dto: PayloadDTO = ResultDTO(success=value.success, value=value.value)
-        raw_value = value.value
-    else:
-        dto = PayloadDTO(value=value)
-        raw_value = value
+    result_dto = ResultDTO(
+        success=result.success,
+        value=result.value,
+    )
 
     try:
-        payload = DEFAULT_CODEC.encode(dto)
+        # Try to encode result directly
+        payload = DEFAULT_CODEC.encode(result_dto)
     except TypeError:
-        dto.value = None  # Remove raw value to avoid serialization error
-        if isinstance(dto, ResultDTO):
-            # Metadata enables client fallback display when result can't be deserialized (protocol mismatch).
-            dto.metadata = _get_metadata(raw_value)
+        result_dto.value = None  # Remove raw value to avoid serialization error
+        result_dto.metadata = _get_metadata(result.value)  # Add metadata for client fallback on protocol mismatch
         try:
             path_dir = PurePath(path).parent.as_posix()
-            protocol_info = DEFAULT_PROTOCOL.save(raw_value, path_dir, session=session)
-            dto.protocol = protocol_info
+            protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
+            result_dto.protocol = protocol_info
 
         except Exception as e:
             logger.warning(f"Error dumping result value: {repr(e)}")
-            # We handle serialization failures differently based on the DTO type:
-            # 1. Job Results (ResultDTO): Allow a "soft-fail."
-            #    Since the job has already executed, we return the serialization error
-            #    to the client so they can debug the output or update their protocol version.
-            # 2. Input Arguments: Trigger a "hard-fail."
-            #    If arguments cannot be saved, the job script cannot run. We raise
-            #    an immediate exception to prevent execution with invalid state.
-            if not isinstance(dto, ResultDTO):
-                raise
-            dto.serialize_error = repr(e)
-
-        # Encode the modified DTO
-        payload = DEFAULT_CODEC.encode(dto)
-
-    if not isinstance(dto, ResultDTO) and len(payload) <= max_inline_size:
-        return payload
+            result_dto.serialize_error = repr(e)
+
+        # Encode the modified result DTO
+        payload = DEFAULT_CODEC.encode(result_dto)
 
     with data_utils.open_stream(path, "wb", session=session) as stream:
         stream.write(payload)
-    return None
-
-
-save_result = save  # Backwards compatibility
 
 
-def load(
-    path_or_data: str,
-    session: Optional[snowpark.Session] = None,
-    path_transform: Optional[Callable[[str], str]] = None,
-) -> Any:
-    """Load data from a file path or inline string."""
-
+def load_result(
+    path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
+) -> ExecutionResult:
+    """Load the result from a file on a Snowflake stage."""
     try:
-        with data_utils.open_stream(path_or_data, "r", session=session) as stream:
+        with data_utils.open_stream(path, "r", session=session) as stream:
             # Load the DTO as a dict for easy fallback to legacy loading if necessary
             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
-    # the exception could be OSError or BlockingIOError (the file name is too long)
-    except OSError as e:
-        # path_or_data might be inline data
-        try:
-            dto_dict = DEFAULT_CODEC.decode(io.StringIO(path_or_data), as_dict=True)
-        except Exception:
-            raise e
     except UnicodeDecodeError:
         # Path may be a legacy result file (cloudpickle)
+        # TODO: Re-use the stream
         assert session is not None
-        return legacy.load_legacy_result(session, path_or_data)
+        return legacy.load_legacy_result(session, path)
 
     try:
-        dto = ResultDTOAdapter.validate_python(dto_dict)
+        dto = ResultDTO.model_validate(dto_dict)
     except pydantic.ValidationError as e:
         if "success" in dto_dict:
             assert session is not None
-            if path_or_data.endswith(".json"):
-                path_or_data = os.path.splitext(path_or_data)[0] + ".pkl"
-            return legacy.load_legacy_result(session, path_or_data, result_json=dto_dict)
+            if path.endswith(".json"):
+                path = os.path.splitext(path)[0] + ".pkl"
+            return legacy.load_legacy_result(session, path, result_json=dto_dict)
         raise ValueError("Invalid result schema") from e
 
     # Try loading data from file using the protocol info
-    payload_value = None
+    result_value = None
     data_load_error = None
     if dto.protocol is not None:
         try:
             logger.debug(f"Loading result value with protocol {dto.protocol}")
-            payload_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
+            result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
         except sp_exceptions.SnowparkSQLException:
             raise  # Data retrieval errors should be bubbled up
         except Exception as e:
             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
-            # Error handling strategy depends on the DTO type:
-            # 1. ResultDTO: Soft-fail. The job has already finished.
-            #    We package the load error into the result so the client can
-            #    debug or adjust their protocol version to retrieve the output.
-            # 2. PayloadDTO: Raise a hard error. If arguments cannot be
-            #    loaded, the job cannot run. We abort early to prevent execution.
-            if not isinstance(dto, ResultDTO):
-                raise
             data_load_error = e
 
-    # Prepare to assemble the final result
-    if not isinstance(dto, ResultDTO):
-        return payload_value
-
+    # Wrap serialize_error in a TypeError
     if dto.serialize_error:
         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
         if data_load_error:
@@ -164,7 +103,8 @@ def load(
     else:
         data_load_error = serialize_error
 
-    result_value = payload_value if payload_value is not None else dto.value
+    # Prepare to assemble the final result
+    result_value = result_value if result_value is not None else dto.value
     if not dto.success and result_value is None:
         # Try to reconstruct exception from metadata if available
         if isinstance(dto.metadata, ExceptionMetadata):
@@ -175,6 +115,7 @@ def load(
                 traceback=dto.metadata.traceback,
                 original_repr=dto.metadata.repr,
             )
+
     # Generate a generic error if we still don't have a value,
     # attaching the data load error if any
     if result_value is None:
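
Taken together, these hunks collapse the generic `save`/`load` payload API back to a result-only surface. A hedged usage sketch of the renamed entry points (signatures are from the hunks above; the `ExecutionResult` constructor arguments, the stage path, and `session` are assumptions):

```python
from snowflake.ml.jobs._interop.results import ExecutionResult

# Assumed constructor shape; only `success` and `value` are attested above.
result = ExecutionResult(success=True, value={"rows": 10})

save_result(result, "@my_stage/output/mljob_result.json", session=session)
loaded = load_result("@my_stage/output/mljob_result.json", session=session)
assert loaded.success and loaded.value == {"rows": 10}
```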
@@ -5,7 +5,6 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-RESULT_VOLUME_NAME = "result-volume"
 DEFAULT_PYTHON_VERSION = "3.10"
 
 # Environment variables
@@ -110,6 +109,3 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
     SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
-
-# Magic attributes
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
@@ -0,0 +1,43 @@
+import inspect
+from typing import Any, Callable, Optional
+
+from snowflake import snowpark
+from snowflake.snowpark import context as sp_context
+
+
+class FunctionPayload:
+    def __init__(
+        self,
+        func: Callable[..., Any],
+        session: Optional[snowpark.Session] = None,
+        session_argument: str = "",
+        *args: Any,
+        **kwargs: Any
+    ) -> None:
+        self.function = func
+        self.args = args
+        self.kwargs = kwargs
+        self._session = session
+        self._session_argument = session_argument
+
+    @property
+    def session(self) -> Optional[snowpark.Session]:
+        return self._session
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude session."""
+        state = self.__dict__.copy()
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        self.__dict__.update(state)
+        self._session = sp_context.get_active_session()
+
+    def __call__(self) -> Any:
+        sig = inspect.signature(self.function)
+        bound = sig.bind_partial(*self.args, **self.kwargs)
+        bound.arguments[self._session_argument] = self._session
+
+        return self.function(*bound.args, **bound.kwargs)
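
This new class appears to take over from the deleted `_pack_obj`/`_unpack_obj` marker scheme: the session is excluded from the pickle in `__getstate__` and re-resolved via `get_active_session()` in `__setstate__`. A hedged round-trip sketch (`my_session` and the `train` function are illustrative; unpickling only works where an active Snowpark session exists):

```python
import cloudpickle  # the backend the jobs machinery uses for payloads


def train(session, table_name: str) -> int:
    return session.table(table_name).count()


# Keyword args avoid binding "t1" to the `session` parameter positionally;
# __call__ injects the live session under the name given in session_argument.
payload = FunctionPayload(train, my_session, "session", table_name="t1")

blob = cloudpickle.dumps(payload)   # __getstate__ drops the session
# ...later, on the job worker where get_active_session() succeeds:
restored = cloudpickle.loads(blob)  # __setstate__ re-attaches the active session
row_count = restored()              # runs train(session=<active>, table_name="t1")
```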