snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +44 -10
- snowflake/ml/_internal/platform_capabilities.py +39 -3
- snowflake/ml/data/data_connector.py +25 -0
- snowflake/ml/dataset/dataset_reader.py +5 -1
- snowflake/ml/jobs/_utils/constants.py +3 -5
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +81 -47
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +27 -8
- snowflake/ml/jobs/_utils/types.py +6 -0
- snowflake/ml/jobs/decorators.py +10 -6
- snowflake/ml/jobs/job.py +145 -23
- snowflake/ml/jobs/manager.py +79 -12
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +57 -39
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
- snowflake/ml/model/_client/sql/service.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +29 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
- snowflake/ml/model/_packager/model_env/model_env.py +8 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/registry/_manager/model_manager.py +20 -1
- snowflake/ml/registry/registry.py +46 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -1,11 +1,13 @@
 import json
 import logging
 import time
+import typing
 from io import BytesIO
 from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Union, cast
 from urllib.parse import urlunparse

 import requests
+from snowflake.core.rest import RESTResponse
 from typing_extensions import NotRequired, deprecated

 from snowflake import snowpark
@@ -72,6 +74,27 @@ class ResponseParseException(Exception):
     pass


+class MidStreamException(Exception):
+    """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
+    using the "error" event type. This exception is raised when there is such a mid-stream error."""
+
+    def __init__(
+        self,
+        reason: typing.Optional[str] = None,
+        http_resp: typing.Optional["RESTResponse"] = None,
+        request_id: typing.Optional[str] = None,
+    ) -> None:
+        message = ""
+        if reason is not None:
+            message = reason
+        if http_resp:
+            message = f"Error in stream (HTTP Response: {http_resp.status}) - {http_resp.reason}"
+        if request_id != "":
+            # add request_id to error message
+            message += f" (Request ID: {request_id})"
+        super().__init__(message)
+
+
 class GuardrailsOptions(TypedDict):
     enabled: bool
     """A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
@@ -120,6 +143,18 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers


+def _get_request_id(resp: Dict[str, Any]) -> Optional[Any]:
+    request_id = None
+    if "headers" in resp:
+        for key, value in resp["headers"].items():
+            # Note: There is some whitespace in the headers making it not possible
+            # to directly index the header reliably.
+            if key.strip().lower() == "x-snowflake-request-id":
+                request_id = value
+                break
+    return request_id
+
+
 def _validate_response_format_object(options: CompleteOptions) -> None:
     """Validate the response format object for structured-output mode.

@@ -188,13 +223,7 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]

-    request_id = None
-    for key, value in raw_resp["headers"].items():
-        # Note: there is some whitespace in the headers making it not possible
-        # to directly index the header reliably.
-        if key.strip().lower() == "x-snowflake-request-id":
-            request_id = value
-            break
+    request_id = _get_request_id(raw_resp)

     data = raw_resp["content"]
     try:
@@ -276,7 +305,12 @@ def _call_complete_rest(
     )


-def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
+def _return_stream_response(
+    response: requests.Response,
+    deadline: Optional[float],
+    session: Optional[snowpark.Session] = None,
+) -> Iterator[str]:
+    request_id = _get_request_id(dict(response.headers))
     client = SSEClient(response)
     for event in client.events():
         if deadline is not None and time.time() > deadline:
@@ -294,7 +328,7 @@ def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
         # This is the case of midstream errors which were introduced specifically for structured output.
         # TODO: discuss during code review
         if parsed_resp.get("error"):
-
+            raise MidStreamException(reason=response.text, request_id=request_id)
         else:
             pass

@@ -375,7 +409,7 @@ def _complete_rest(
     else:
         response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
     assert response.status_code >= 200 and response.status_code < 300
-    return _return_stream_response(response, deadline)
+    return _return_stream_response(response, deadline, session)


 def _complete_impl(
snowflake/ml/_internal/platform_capabilities.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+from contextlib import contextmanager
 from typing import Any, Dict, Optional

 from absl import logging
@@ -27,16 +28,45 @@ class PlatformCapabilities:
     """

     _instance: Optional["PlatformCapabilities"] = None
+    # Used for unittesting only. This is to avoid the need to mock the session object or reaching out to Snowflake
+    _mock_features: Optional[Dict[str, Any]] = None

     @classmethod
     def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
+        # Used for unittesting only. In this situation, _instance is not initialized.
+        if cls._mock_features is not None:
+            return cls(features=cls._mock_features)
         if not cls._instance:
-            cls._instance = cls(session)
+            cls._instance = cls(session=session)
         return cls._instance

+    @classmethod
+    def set_mock_features(cls, features: Optional[Dict[str, Any]] = None) -> None:
+        cls._mock_features = features
+
+    @classmethod
+    def clear_mock_features(cls) -> None:
+        cls._mock_features = None
+
+    # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
+    # Python 3.11. So, we are ignoring the type for this method.
+    @classmethod  # type: ignore[arg-type]
+    @contextmanager
+    def mock_features(cls, features: Dict[str, Any]) -> None:  # type: ignore[misc]
+        logging.debug(f"Setting mock features: {features}")
+        cls.set_mock_features(features)
+        try:
+            yield
+        finally:
+            logging.debug(f"Clearing mock features: {features}")
+            cls.clear_mock_features()
+
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)

+    def is_inlined_deployment_spec_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)

@@ -68,11 +98,17 @@ class PlatformCapabilities:
         # This can happen is server side is older than 9.2. That is fine.
         return {}

-    def __init__(
+    def __init__(
+        self, *, session: Optional[snowpark_session.Session] = None, features: Optional[Dict[str, Any]] = None
+    ) -> None:
+        # This is for testing purposes only.
+        if features:
+            self.features = features
+            return
         if not session:
             session = next(iter(snowpark_session._get_active_sessions()))
         assert session, "Missing active session object"
-        self.features
+        self.features = PlatformCapabilities._get_features(session)

     def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
         value = self.features.get(feature_name, default_value)
snowflake/ml/data/data_connector.py
CHANGED
@@ -27,6 +27,7 @@ from snowflake.snowpark import context as sf_context

 if TYPE_CHECKING:
     import pandas as pd
+    import ray
     import tensorflow as tf
     from torch.utils import data as torch_data

@@ -241,6 +242,30 @@ class DataConnector:
         """
         return self._ingestor.to_pandas(limit)

+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_ray_dataset(self) -> "ray.data.Dataset":
+        """Retrieve the Snowflake data as a Ray Dataset.
+
+        Returns:
+            A Ray Dataset.
+
+        Raises:
+            ImportError: If Ray is not installed in the local environment.
+        """
+        if hasattr(self._ingestor, "to_ray_dataset"):
+            return self._ingestor.to_ray_dataset()
+
+        try:
+            import ray
+
+            return ray.data.from_pandas(self._ingestor.to_pandas())
+        except ImportError as e:
+            raise ImportError("Ray is not installed, please install ray in your local environment.") from e
+

 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
snowflake/ml/dataset/dataset_reader.py
CHANGED
@@ -5,6 +5,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
+from snowflake.snowpark._internal import utils as snowpark_utils

 _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
@@ -94,7 +95,10 @@ class DatasetReader(data_connector.DataConnector):
         dfs: List[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
-            df = self._session.read.option("pattern", file_path_pattern).parquet(source.url)
+            stage_reader = self._session.read.option("pattern", file_path_pattern)
+            if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
+                stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
+            df = stage_reader.parquet(source.url)
             if only_feature_cols and source.exclude_cols:
                 df = df.drop(source.exclude_cols)
             dfs.append(df)
snowflake/ml/jobs/_utils/constants.py
CHANGED
@@ -4,6 +4,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 # SPCS specification constants
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
 STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -12,16 +13,12 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "0.
+DEFAULT_IMAGE_TAG = "1.0.1"
 DEFAULT_ENTRYPOINT_PATH = "func.py"

 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3

-# Multi Node Headless prototype constants
-# TODO: Replace this placeholder with the actual container runtime image tag.
-MULTINODE_HEADLESS_IMAGE_TAG = "latest"
-
 # Ray port configuration
 RAY_PORTS = {
     "HEAD_CLIENT_SERVER_PORT": "10001",
@@ -48,6 +45,7 @@ JOB_POLL_MAX_DELAY_SECONDS = 1

 # Magic attributes
 IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
+RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"

 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
|