snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -28
- snowflake/ml/model/_client/ops/model_ops.py +2 -8
- snowflake/ml/model/_client/ops/service_ops.py +6 -11
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +21 -29
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
- snowflake/ml/model/_signatures/utils.py +76 -1
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ class StageFileWriter(io.IOBase):
         # Only upload if buffer has content and no exception occurred
         if write_contents and self._buffer.tell() > 0:
             self._buffer.seek(0)
-            self._session.file.put_stream(self._buffer, self._path)
+            self._session.file.put_stream(self._buffer, self._path, auto_compress=False)
         self._buffer.close()
         self._closed = True
 
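
For context on the `auto_compress=False` addition above: Snowpark's `put_stream` gzips uploads by default, which stores the staged file under a `.gz` name and would break reading it back at the original path. A minimal sketch, using a hypothetical `upload_uncompressed` helper (not from the diff):

```python
import io

from snowflake.snowpark import Session


def upload_uncompressed(session: Session, data: bytes, stage_path: str) -> None:
    # With the Snowpark default (auto_compress=True) the uploaded file is gzipped
    # and stored as "<name>.gz"; reading back the original name would then fail.
    session.file.put_stream(io.BytesIO(data), stage_path, auto_compress=False)
```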
@@ -84,15 +84,15 @@ class DtoCodec(Protocol):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
         pass
 
     @staticmethod
-    def encode(dto: dto_schema.ResultDTO) -> bytes:
+    def encode(dto: dto_schema.PayloadDTO) -> bytes:
         pass
 
 
@@ -104,18 +104,18 @@ class JsonDtoCodec(DtoCodec):
 
     @overload
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
+    def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.PayloadDTO:
         ...
 
     @staticmethod
-    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
+    def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.PayloadDTO, dict[str, Any]]:
         data = cast(dict[str, Any], json.load(stream))
         if as_dict:
             return data
-        return dto_schema.
+        return dto_schema.ResultDTOAdapter.validate_python(data)
 
     @staticmethod
-    def encode(dto: dto_schema.ResultDTO) -> bytes:
+    def encode(dto: dto_schema.PayloadDTO) -> bytes:
         # Temporarily extract the value to avoid accidentally applying model_dump() on it
         result_value = dto.value
         dto.value = None # Clear value to avoid serializing it in the model_dump
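
A hedged usage sketch of the updated codec; the import location of `JsonDtoCodec` is assumed (the hunks do not name its module), so treat it as illustrative:

```python
import io

from snowflake.ml.jobs._interop import dto_schema
from snowflake.ml.jobs._interop.data_utils import JsonDtoCodec  # assumed location

dto = dto_schema.PayloadDTO(value={"answer": 42})
payload: bytes = JsonDtoCodec.encode(dto)

# as_dict=True returns the raw dict; the default path validates into a DTO
# via the new ResultDTOAdapter discriminator.
raw = JsonDtoCodec.decode(io.BytesIO(payload), as_dict=True)
validated = JsonDtoCodec.decode(io.BytesIO(payload))
assert isinstance(validated, dto_schema.PayloadDTO)
```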
@@ -1,7 +1,7 @@
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
-from pydantic import BaseModel, model_validator
-from typing_extensions import NotRequired, TypedDict
+from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, model_validator
+from typing_extensions import Annotated, NotRequired, TypedDict
 
 
 class BinaryManifest(TypedDict):
@@ -67,22 +67,47 @@ class ExceptionMetadata(ResultMetadata):
     traceback: str
 
 
-class ResultDTO(BaseModel):
+class PayloadDTO(BaseModel):
+    """
+    Base class for serializable payloads.
+
+    Args:
+        kind: Discriminator field for DTO type dispatch.
+        value: The payload value (if JSON-serializable).
+        protocol: The protocol used to serialize the payload (if not JSON-serializable).
+    """
+
+    kind: Literal["base"] = "base"
+    value: Optional[Any] = None
+    protocol: Optional[ProtocolInfo] = None
+    serialize_error: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_fields(cls, data: Any) -> Any:
+        """Ensure at least one of value or protocol keys is specified."""
+        if cls is PayloadDTO and isinstance(data, dict):
+            required_fields = {"value", "protocol"}
+            if not any(field in data for field in required_fields):
+                raise ValueError("At least one of 'value' or 'protocol' must be specified")
+        return data
+
+
+class ResultDTO(PayloadDTO):
     """
     A JSON representation of an execution result.
 
     Args:
+        kind: Discriminator field for DTO type dispatch.
         success: Whether the execution was successful.
         value: The value of the execution or the exception if the execution failed.
         protocol: The protocol used to serialize the result.
         metadata: The metadata of the result.
     """
 
+    kind: Literal["result"] = "result" # type: ignore[assignment]
     success: bool
-    value: Optional[Any] = None
-    protocol: Optional[ProtocolInfo] = None
     metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
-    serialize_error: Optional[str] = None
 
     @model_validator(mode="before")
     @classmethod
@@ -93,3 +118,23 @@ class ResultDTO(BaseModel):
         if not any(field in data for field in required_fields):
             raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
         return data
+
+
+def _get_dto_kind(data: Any) -> str:
+    """Extract the 'kind' discriminator from input, defaulting to 'result' for backward compatibility."""
+    if isinstance(data, dict):
+        kind = data.get("kind", "result")
+    else:
+        kind = getattr(data, "kind", "result")
+    return str(kind)
+
+
+AnyResultDTO = Annotated[
+    Union[
+        Annotated[ResultDTO, Tag("result")],
+        Annotated[PayloadDTO, Tag("base")],
+    ],
+    Discriminator(_get_dto_kind),
+]
+
+ResultDTOAdapter: TypeAdapter[AnyResultDTO] = TypeAdapter(AnyResultDTO)
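
The discriminated union above lets one adapter parse both argument payloads and results; when `kind` is absent, `_get_dto_kind` falls back to `"result"`, so payloads written by older clients still validate. A minimal sketch (illustrative only; the import path matches the hunks):

```python
from snowflake.ml.jobs._interop.dto_schema import (
    PayloadDTO,
    ResultDTO,
    ResultDTOAdapter,
)

# "kind" missing -> defaults to "result", so legacy result payloads still parse as ResultDTO.
legacy = ResultDTOAdapter.validate_python({"success": True, "value": 123})
assert isinstance(legacy, ResultDTO)

# Explicit kinds route to the matching model.
arg = ResultDTOAdapter.validate_python({"kind": "base", "value": {"x": 1}})
assert isinstance(arg, PayloadDTO) and not isinstance(arg, ResultDTO)
```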
@@ -17,6 +17,8 @@ Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
 
 logger = logging.getLogger(__name__)
 
+SESSION_KEY_PREFIX = "session@"
+
 
 class SerializationError(TypeError):
     """Exception raised when a serialization protocol fails."""
@@ -136,9 +138,10 @@ class CloudPickleProtocol(SerializationProtocol):
 
     def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
         """Save the object to the destination directory."""
+        replaced_obj = self._pack_obj(obj)
         result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
         with data_utils.open_stream(result_path, "wb", session=session) as f:
-            self._backend.dump(obj, f)
+            self._backend.dump(replaced_obj, f)
         manifest: BinaryManifest = {"path": result_path}
         return self.protocol_info.with_manifest(manifest)
 
@@ -157,12 +160,15 @@ class CloudPickleProtocol(SerializationProtocol):
         payload_manifest = cast(BinaryManifest, payload_info.manifest)
         try:
             if payload_bytes := payload_manifest.get("bytes"):
-
-
-
-
-
-
+                result = self._backend.loads(payload_bytes)
+            elif payload_b64 := payload_manifest.get("base64"):
+                result = self._backend.loads(base64.b64decode(payload_b64))
+            else:
+                result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+                with data_utils.open_stream(result_path, "rb", session=session) as f:
+                    result = self._backend.load(f)
+
+            return self._unpack_obj(result, session=session)
         except (
             pickle.UnpicklingError,
             TypeError,
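
For reference, a hedged sketch of the three manifest shapes the branch above handles; the key names (`bytes`, `base64`, `path`) come from this hunk, while the values and the use of `cloudpickle` as the backend are illustrative assumptions:

```python
import base64

import cloudpickle

small_inline = {"bytes": cloudpickle.dumps({"x": 1})}  # raw pickled bytes, kept in memory
text_safe = {"base64": base64.b64encode(cloudpickle.dumps(2)).decode()}  # JSON-safe transport
staged = {"path": "@my_stage/mljob/result.pkl"}  # pointer to a staged file to stream back
```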
@@ -173,6 +179,117 @@ class CloudPickleProtocol(SerializationProtocol):
                 raise error from pickle_error
             raise
 
+    def _pack_obj(self, obj: Any) -> Any:
+        """Pack objects into JSON-safe dicts using reserved marker keys.
+
+        Markers:
+            - "type@": container type for list/tuple (list or tuple)
+            - "#<i>": positional element for list/tuple at index i
+            - "session@": placeholder for snowpark.Session values
+              - "session@#<i>" for list/tuple entries
+              - "session@<key>" for dict entries
+              - {"session@": None} for a bare Session object
+
+        Example:
+            obj = {"x": [1, session], "s": session}
+            packed = {
+                "x": {"type@": list, "#0": 1, "session@#1": None},
+                "session@s": None,
+            }
+            _unpack_obj(packed, session) == obj
+
+        Args:
+            obj: Object to pack into JSON-safe marker dictionaries.
+
+        Returns:
+            Packed representation with markers for session references.
+        """
+        arguments: dict[str, Any] = {}
+        if isinstance(obj, tuple) or isinstance(obj, list):
+            arguments = {"type@": type(obj)}
+            for i, arg in enumerate(obj):
+                if isinstance(arg, snowpark.Session):
+                    arguments[f"{SESSION_KEY_PREFIX}#{i}"] = None
+                else:
+                    arguments[f"#{i}"] = self._pack_obj(arg)
+            return arguments
+        elif isinstance(obj, dict):
+            for k, v in obj.items():
+                if isinstance(v, snowpark.Session):
+                    arguments[f"{SESSION_KEY_PREFIX}{k}"] = None
+                else:
+                    arguments[k] = self._pack_obj(v)
+            return arguments
+        elif isinstance(obj, snowpark.Session):
+            # Box session into a dict marker so we can distinguish it from other plain objects.
+            arguments[f"{SESSION_KEY_PREFIX}"] = None
+            return arguments
+        else:
+            return obj
+
+    def _unpack_obj(self, obj: Any, session: Optional[snowpark.Session] = None) -> Any:
+        """Unpack dict markers back into containers and Session references.
+
+        Markers:
+            - "type@": container type for list/tuple (list or tuple)
+            - "#<i>": positional element for list/tuple at index i
+            - "session@": placeholder for snowpark.Session values
+              - "session@#<i>" for list/tuple entries
+              - "session@<key>" for dict entries
+              - {"session@": None} for a bare Session object
+
+        Example:
+            packed = {
+                "x": {"type@": list, "#0": 1, "session@#1": None},
+                "session@s": None,
+            }
+            obj = _unpack_obj(packed, session)
+            # obj == {"x": [1, session], "s": session}
+
+        Args:
+            obj: Packed object with marker dictionaries.
+            session: Session to inject for session markers.
+
+        Returns:
+            Unpacked object with session references restored.
+        """
+        if not isinstance(obj, dict):
+            return obj
+        elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
+            return session
+        else:
+            type = obj.get("type@", None)
+            # If type is None, we are unpacking a dict
+            if type is None:
+                result_dict = {}
+                for k, v in obj.items():
+                    if k.startswith(SESSION_KEY_PREFIX):
+                        result_key = k[len(SESSION_KEY_PREFIX) :]
+                        result_dict[result_key] = session
+                    else:
+                        result_dict[k] = self._unpack_obj(v, session)
+                return result_dict
+            # If type is not None, we are unpacking a tuple or list
+            else:
+                indexes = []
+                for k, _ in obj.items():
+                    if "#" in k:
+                        indexes.append(int(k.split("#")[-1]))
+
+                if not indexes:
+                    return tuple() if type is tuple else []
+                result_list: list[Any] = [None] * (max(indexes) + 1)
+
+                for k, v in obj.items():
+                    if k == "type@":
+                        continue
+                    idx = int(k.split("#")[-1])
+                    if k.startswith(SESSION_KEY_PREFIX):
+                        result_list[idx] = session
+                    else:
+                        result_list[idx] = self._unpack_obj(v, session)
+                return tuple(result_list) if type is tuple else result_list
+
 
 class ArrowTableProtocol(SerializationProtocol):
     """
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import traceback
@@ -10,7 +11,9 @@ from snowflake import snowpark
 from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
 from snowflake.ml.jobs._interop.dto_schema import (
     ExceptionMetadata,
+    PayloadDTO,
     ResultDTO,
+    ResultDTOAdapter,
     ResultMetadata,
 )
 from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
@@ -23,79 +26,137 @@ DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
 DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
 
+# Constants for argument encoding
+_MAX_INLINE_SIZE = 1024 * 1024 # 1MB - https://docs.snowflake.com/en/user-guide/query-size-limits
 
 logger = logging.getLogger(__name__)
 
 
-def
+def save(
+    value: Any,
+    path: str,
+    session: Optional[snowpark.Session] = None,
+    max_inline_size: int = 0,
+) -> Optional[bytes]:
     """
-
+    Serialize a value. Returns inline bytes if small enough, else writes to file.
+
+    Args:
+        value: The value to serialize. If ExecutionResult, creates ResultDTO with success flag.
+        path: Full file path for writing the DTO (if needed). Protocol data saved to path's parent.
+        session: Snowpark session for stage operations.
+        max_inline_size: Max bytes for inline return. 0 = always write to file.
+
+    Returns:
+        Encoded bytes if <= max_inline_size, else None (written to file).
+
+    Raises:
+        Exception: If session validation fails during serialization.
     """
-
-        success=
-        value
-
+    if isinstance(value, ExecutionResult):
+        dto: PayloadDTO = ResultDTO(success=value.success, value=value.value)
+        raw_value = value.value
+    else:
+        dto = PayloadDTO(value=value)
+        raw_value = value
 
     try:
-
-        payload = DEFAULT_CODEC.encode(result_dto)
+        payload = DEFAULT_CODEC.encode(dto)
     except TypeError:
-
-
+        dto.value = None # Remove raw value to avoid serialization error
+        if isinstance(dto, ResultDTO):
+            # Metadata enables client fallback display when result can't be deserialized (protocol mismatch)..
+            dto.metadata = _get_metadata(raw_value)
         try:
             path_dir = PurePath(path).parent.as_posix()
-            protocol_info = DEFAULT_PROTOCOL.save(
-
+            protocol_info = DEFAULT_PROTOCOL.save(raw_value, path_dir, session=session)
+            dto.protocol = protocol_info
 
         except Exception as e:
             logger.warning(f"Error dumping result value: {repr(e)}")
-
-
-
-
+            # We handle serialization failures differently based on the DTO type:
+            # 1. Job Results (ResultDTO): Allow a "soft-fail."
+            #    Since the job has already executed, we return the serialization error
+            #    to the client so they can debug the output or update their protocol version.
+            # 2. Input Arguments: Trigger a "hard-fail."
+            #    If arguments cannot be saved, the job script cannot run. We raise
+            #    an immediate exception to prevent execution with invalid state.
+            if not isinstance(dto, ResultDTO):
+                raise
+            dto.serialize_error = repr(e)
+
+        # Encode the modified DTO
+        payload = DEFAULT_CODEC.encode(dto)
+
+    if not isinstance(dto, ResultDTO) and len(payload) <= max_inline_size:
+        return payload
 
     with data_utils.open_stream(path, "wb", session=session) as stream:
         stream.write(payload)
+    return None
+
+
+save_result = save # Backwards compatibility
 
 
-def
-
-
-
+def load(
+    path_or_data: str,
+    session: Optional[snowpark.Session] = None,
+    path_transform: Optional[Callable[[str], str]] = None,
+) -> Any:
+    """Load data from a file path or inline string."""
+
     try:
-        with data_utils.open_stream(
+        with data_utils.open_stream(path_or_data, "r", session=session) as stream:
             # Load the DTO as a dict for easy fallback to legacy loading if necessary
             dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+    # the exception could be OSError or BlockingIOError(the file name is too long)
+    except OSError as e:
+        # path_or_data might be inline data
+        try:
+            dto_dict = DEFAULT_CODEC.decode(io.StringIO(path_or_data), as_dict=True)
+        except Exception:
+            raise e
     except UnicodeDecodeError:
         # Path may be a legacy result file (cloudpickle)
-        # TODO: Re-use the stream
         assert session is not None
-        return legacy.load_legacy_result(session,
+        return legacy.load_legacy_result(session, path_or_data)
 
     try:
-        dto =
+        dto = ResultDTOAdapter.validate_python(dto_dict)
     except pydantic.ValidationError as e:
         if "success" in dto_dict:
            assert session is not None
-            if
-
-            return legacy.load_legacy_result(session,
+            if path_or_data.endswith(".json"):
+                path_or_data = os.path.splitext(path_or_data)[0] + ".pkl"
+            return legacy.load_legacy_result(session, path_or_data, result_json=dto_dict)
         raise ValueError("Invalid result schema") from e
 
     # Try loading data from file using the protocol info
-
+    payload_value = None
     data_load_error = None
     if dto.protocol is not None:
         try:
             logger.debug(f"Loading result value with protocol {dto.protocol}")
-
+            payload_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
         except sp_exceptions.SnowparkSQLException:
             raise # Data retrieval errors should be bubbled up
         except Exception as e:
             logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+            # Error handling strategy depends on the DTO type:
+            # 1. ResultDTO: Soft-fail. The job has already finished.
+            #    We package the load error into the result so the client can
+            #    debug or adjust their protocol version to retrieve the output.
+            # 2. PayloadDTO : Raise a hard error. If arguments cannot be
+            #    loaded, the job cannot run. We abort early to prevent execution.
+            if not isinstance(dto, ResultDTO):
+                raise
             data_load_error = e
 
-    #
+    # Prepare to assemble the final result
+    if not isinstance(dto, ResultDTO):
+        return payload_value
+
     if dto.serialize_error:
         serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
         if data_load_error:
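
A hedged sketch of the `save()` contract documented in the hunk above; the module path is taken from the package file list (`snowflake/ml/jobs/_interop/utils.py`) and is an assumption, as is the helper name:

```python
from typing import Any, Optional

from snowflake.ml.jobs._interop import utils  # assumed module for the hunk above
from snowflake.snowpark import Session


def stash_args(session: Session, args: dict[str, Any], stage_path: str) -> Optional[bytes]:
    # Returns the encoded DTO bytes when they fit under the inline limit;
    # otherwise returns None and the DTO is written to stage_path instead.
    return utils.save(args, stage_path, session=session, max_inline_size=utils._MAX_INLINE_SIZE)
```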
@@ -103,8 +164,7 @@ def load_result(
         else:
             data_load_error = serialize_error
 
-
-    result_value = result_value if result_value is not None else dto.value
+    result_value = payload_value if payload_value is not None else dto.value
     if not dto.success and result_value is None:
         # Try to reconstruct exception from metadata if available
         if isinstance(dto.metadata, ExceptionMetadata):
@@ -115,7 +175,6 @@ def load_result(
                 traceback=dto.metadata.traceback,
                 original_repr=dto.metadata.repr,
             )
-
     # Generate a generic error if we still don't have a value,
     # attaching the data load error if any
     if result_value is None:
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
+RESULT_VOLUME_NAME = "result-volume"
 DEFAULT_PYTHON_VERSION = "3.10"
 
 # Environment variables
@@ -109,3 +110,6 @@ CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
     SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
+
+# Magic attributes
+IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
@@ -1,6 +1,11 @@
 import os
-from
-
+from typing import Callable, Optional, Union
+
+from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
+from snowflake.snowpark import context as sp_context
+
+# Default value type: can be a bool or a callable that returns a bool
+DefaultValue = Union[bool, Callable[[], bool]]
 
 
 def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
@@ -28,22 +33,101 @@ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
     return default
 
 
-
-
-
-
+def _enabled_in_clouds(*clouds: SnowflakeCloudType) -> Callable[[], bool]:
+    """Create a callable that checks if the current environment is in any of the specified clouds.
+
+    This factory function returns a callable that can be used as a dynamic default
+    for feature flags. The returned callable will check if the current Snowflake
+    session is connected to a region in any of the specified cloud providers.
+
+    Args:
+        *clouds: One or more SnowflakeCloudType values to check against.
+
+    Returns:
+        A callable that returns True if running in any of the specified clouds,
+        False otherwise (including when no session is available).
+
+    Example:
+        >>> # Enable feature only in GCP
+        >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP)
+        >>>
+        >>> # Enable feature in both GCP and Azure
+        >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP, SnowflakeCloudType.AZURE)
+    """
+    cloud_set = frozenset(clouds)
+
+    def check() -> bool:
+        try:
+            from snowflake.ml._internal.utils.snowflake_env import get_current_cloud
+
+            session = sp_context.get_active_session()
+            current_cloud = get_current_cloud(session, default=SnowflakeCloudType.AWS)
+            return current_cloud in cloud_set
+        except Exception:
+            # If we can't determine the cloud (no session, SQL error, etc.),
+            # default to False for safety
+            return False
+
+    return check
 
-
-
+
+
+class _FeatureFlag:
+    """A feature flag backed by an environment variable with a configurable default.
+
+    The default value can be a constant boolean or a callable that dynamically
+    determines the default based on runtime context (e.g., cloud provider).
+    """
+
+    def __init__(self, env_var: str, default: DefaultValue = False) -> None:
+        """Initialize a feature flag.
 
         Args:
-
+            env_var: The environment variable name that controls this flag.
+            default: The default value when the env var is not set. Can be:
+                - A boolean constant (True/False)
+                - A callable that returns a boolean (evaluated at check time)
+        """
+        self._env_var = env_var
+        self._default = default
+
+    @property
+    def value(self) -> str:
+        """Return the environment variable name (for compatibility with Enum-style access)."""
+        return self._env_var
+
+    def _get_default(self) -> bool:
+        """Get the default value, calling it if it's a callable."""
+        if callable(self._default):
+            return self._default()
+        return self._default
+
+    def is_enabled(self) -> bool:
+        """Check if the feature flag is enabled.
+
+        First checks the environment variable. If not set or unrecognized,
+        falls back to the configured default value.
 
         Returns:
-            True if the
-            False if set to a falsy value, or the default value if not set.
+            True if the feature is enabled, False otherwise.
         """
-
+        env_value = os.getenv(self._env_var)
+        if env_value is not None:
+            # Environment variable is set, parse it
+            result = parse_bool_env_value(env_value, default=self._get_default())
+            return result
+        else:
+            # Environment variable not set, use the default
+            return self._get_default()
 
     def __str__(self) -> str:
-        return self.
+        return self._env_var
+
+
+class FeatureFlags:
+    """Collection of feature flags for ML Jobs."""
+
+    ENABLE_RUNTIME_VERSIONS = _FeatureFlag("MLRS_ENABLE_RUNTIME_VERSIONS", default=True)
+    ENABLE_STAGE_MOUNT_V2 = _FeatureFlag(
+        "MLRS_ENABLE_STAGE_MOUNT_V2",
+        default=_enabled_in_clouds(SnowflakeCloudType.GCP),
+    )