snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. snowflake/cortex/_complete.py +44 -10
  2. snowflake/ml/_internal/platform_capabilities.py +39 -3
  3. snowflake/ml/data/data_connector.py +25 -0
  4. snowflake/ml/dataset/dataset_reader.py +5 -1
  5. snowflake/ml/jobs/_utils/constants.py +2 -4
  6. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +81 -47
  8. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
  11. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  12. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  13. snowflake/ml/jobs/_utils/spec_utils.py +5 -8
  14. snowflake/ml/jobs/_utils/types.py +6 -0
  15. snowflake/ml/jobs/decorators.py +3 -3
  16. snowflake/ml/jobs/job.py +145 -23
  17. snowflake/ml/jobs/manager.py +62 -10
  18. snowflake/ml/model/_client/ops/service_ops.py +42 -35
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
  20. snowflake/ml/model/_client/sql/service.py +9 -5
  21. snowflake/ml/model/_model_composer/model_composer.py +29 -11
  22. snowflake/ml/model/_packager/model_env/model_env.py +8 -2
  23. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
  24. snowflake/ml/model/_packager/model_packager.py +2 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  26. snowflake/ml/model/type_hints.py +2 -0
  27. snowflake/ml/registry/_manager/model_manager.py +20 -1
  28. snowflake/ml/registry/registry.py +5 -1
  29. snowflake/ml/version.py +1 -1
  30. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +35 -4
  31. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +34 -28
  32. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +0 -0
  33. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  34. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
@@ -1,11 +1,13 @@
 import json
 import logging
 import time
+import typing
 from io import BytesIO
 from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Union, cast
 from urllib.parse import urlunparse

 import requests
+from snowflake.core.rest import RESTResponse
 from typing_extensions import NotRequired, deprecated

 from snowflake import snowpark
@@ -72,6 +74,27 @@ class ResponseParseException(Exception):
     pass


+class MidStreamException(Exception):
+    """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
+    using the “error” event type. This exception is raised when there is such a mid-stream error."""
+
+    def __init__(
+        self,
+        reason: typing.Optional[str] = None,
+        http_resp: typing.Optional["RESTResponse"] = None,
+        request_id: typing.Optional[str] = None,
+    ) -> None:
+        message = ""
+        if reason is not None:
+            message = reason
+        if http_resp:
+            message = f"Error in stream (HTTP Response: {http_resp.status}) - {http_resp.reason}"
+        if request_id != "":
+            # add request_id to error message
+            message += f" (Request ID: {request_id})"
+        super().__init__(message)
+
+
 class GuardrailsOptions(TypedDict):
     enabled: bool
     """A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
@@ -120,6 +143,18 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers


+def _get_request_id(resp: Dict[str, Any]) -> Optional[Any]:
+    request_id = None
+    if "headers" in resp:
+        for key, value in resp["headers"].items():
+            # Note: There is some whitespace in the headers making it not possible
+            # to directly index the header reliably.
+            if key.strip().lower() == "x-snowflake-request-id":
+                request_id = value
+                break
+    return request_id
+
+
 def _validate_response_format_object(options: CompleteOptions) -> None:
     """Validate the response format object for structured-output mode.

@@ -188,13 +223,7 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]

-    request_id = None
-    for key, value in raw_resp["headers"].items():
-        # Note: there is some whitespace in the headers making it not possible
-        # to directly index the header reliably.
-        if key.strip().lower() == "x-snowflake-request-id":
-            request_id = value
-            break
+    request_id = _get_request_id(raw_resp)

     data = raw_resp["content"]
     try:
@@ -276,7 +305,12 @@ def _call_complete_rest(
     )


-def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
+def _return_stream_response(
+    response: requests.Response,
+    deadline: Optional[float],
+    session: Optional[snowpark.Session] = None,
+) -> Iterator[str]:
+    request_id = _get_request_id(dict(response.headers))
     client = SSEClient(response)
     for event in client.events():
         if deadline is not None and time.time() > deadline:
@@ -294,7 +328,7 @@ def _return_stream_response(response: requests.Response, deadline: Optional[floa
             # This is the case of midstream errors which were introduced specifically for structured output.
             # TODO: discuss during code review
             if parsed_resp.get("error"):
-                yield json.dumps(parsed_resp)
+                raise MidStreamException(reason=response.text, request_id=request_id)
             else:
                 pass

@@ -375,7 +409,7 @@ def _complete_rest(
     else:
         response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
     assert response.status_code >= 200 and response.status_code < 300
-    return _return_stream_response(response, deadline)
+    return _return_stream_response(response, deadline, session)


 def _complete_impl(
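
For orientation, a hedged sketch of what the streaming change means for callers: a mid-stream "error" SSE event, which 1.8.1 yielded as raw JSON text, is now raised as MidStreamException. The sketch assumes the public Complete entry point from snowflake.cortex with stream=True and imports the exception from the private module shown above; the model name and prompt are placeholders.

# Sketch only: consuming a streamed Complete() response under the new behavior.
from snowflake.cortex import Complete
from snowflake.cortex._complete import MidStreamException  # added in 1.8.2


def stream_answer(prompt: str) -> str:
    chunks = []
    try:
        # stream=True returns an Iterator[str] of response deltas.
        for chunk in Complete("mistral-large2", prompt, stream=True):  # model name is illustrative
            chunks.append(chunk)
    except MidStreamException as exc:
        # In 1.8.1 a mid-stream "error" event was yielded as raw JSON; as of 1.8.2
        # it surfaces as an exception carrying the request ID in its message.
        raise RuntimeError(f"Cortex stream failed: {exc}") from exc
    return "".join(chunks)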
snowflake/ml/_internal/platform_capabilities.py
@@ -1,4 +1,5 @@
 import json
+from contextlib import contextmanager
 from typing import Any, Dict, Optional

 from absl import logging
@@ -27,16 +28,45 @@ class PlatformCapabilities:
     """

     _instance: Optional["PlatformCapabilities"] = None
+    # Used for unittesting only. This is to avoid the need to mock the session object or reaching out to Snowflake
+    _mock_features: Optional[Dict[str, Any]] = None

     @classmethod
     def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
+        # Used for unittesting only. In this situation, _instance is not initialized.
+        if cls._mock_features is not None:
+            return cls(features=cls._mock_features)
         if not cls._instance:
-            cls._instance = cls(session)
+            cls._instance = cls(session=session)
         return cls._instance

+    @classmethod
+    def set_mock_features(cls, features: Optional[Dict[str, Any]] = None) -> None:
+        cls._mock_features = features
+
+    @classmethod
+    def clear_mock_features(cls) -> None:
+        cls._mock_features = None
+
+    # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
+    # Python 3.11. So, we are ignoring the type for this method.
+    @classmethod  # type: ignore[arg-type]
+    @contextmanager
+    def mock_features(cls, features: Dict[str, Any]) -> None:  # type: ignore[misc]
+        logging.debug(f"Setting mock features: {features}")
+        cls.set_mock_features(features)
+        try:
+            yield
+        finally:
+            logging.debug(f"Clearing mock features: {features}")
+            cls.clear_mock_features()
+
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)

+    def is_inlined_deployment_spec_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)

@@ -68,11 +98,17 @@ class PlatformCapabilities:
             # This can happen is server side is older than 9.2. That is fine.
             return {}

-    def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
+    def __init__(
+        self, *, session: Optional[snowpark_session.Session] = None, features: Optional[Dict[str, Any]] = None
+    ) -> None:
+        # This is for testing purposes only.
+        if features:
+            self.features = features
+            return
         if not session:
             session = next(iter(snowpark_session._get_active_sessions()))
         assert session, "Missing active session object"
-        self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
+        self.features = PlatformCapabilities._get_features(session)

     def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
         value = self.features.get(feature_name, default_value)
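
A minimal test sketch for the new mocking hooks, using only the methods added above; the feature flag checked here is the ENABLE_INLINE_DEPLOYMENT_SPEC capability introduced in the same change.

# Sketch only: exercising capability-dependent code without a live Snowflake session.
from snowflake.ml._internal.platform_capabilities import PlatformCapabilities


def test_inlined_deployment_spec_flag() -> None:
    with PlatformCapabilities.mock_features({"ENABLE_INLINE_DEPLOYMENT_SPEC": True}):
        pc = PlatformCapabilities.get_instance()  # returns a mocked instance; no session lookup
        assert pc.is_inlined_deployment_spec_enabled()

    # Leaving the context clears the mock again.
    assert PlatformCapabilities._mock_features is None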
snowflake/ml/data/data_connector.py
@@ -27,6 +27,7 @@ from snowflake.snowpark import context as sf_context

 if TYPE_CHECKING:
     import pandas as pd
+    import ray
    import tensorflow as tf
     from torch.utils import data as torch_data

@@ -241,6 +242,30 @@ class DataConnector:
         """
         return self._ingestor.to_pandas(limit)

+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_ray_dataset(self) -> "ray.data.Dataset":
+        """Retrieve the Snowflake data as a Ray Dataset.
+
+        Returns:
+            A Ray Dataset.
+
+        Raises:
+            ImportError: If Ray is not installed in the local environment.
+        """
+        if hasattr(self._ingestor, "to_ray_dataset"):
+            return self._ingestor.to_ray_dataset()
+
+        try:
+            import ray
+
+            return ray.data.from_pandas(self._ingestor.to_pandas())
+        except ImportError as e:
+            raise ImportError("Ray is not installed, please install ray in your local environment.") from e
+

 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
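
A hedged usage sketch for the new to_ray_dataset() API. It assumes a connector built with DataConnector.from_dataframe and that ray is installed locally; otherwise the ImportError shown above is raised. The table name is a placeholder.

# Sketch only: pulling Snowflake data into Ray via the new to_ray_dataset() API.
import ray.data  # requires ray in the local environment

from snowflake.ml.data import DataConnector
from snowflake.snowpark import Session


def to_ray(session: Session) -> "ray.data.Dataset":
    sp_df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")  # placeholder table name
    connector = DataConnector.from_dataframe(sp_df)
    # Uses the ingestor's native Ray path when available; otherwise falls back to
    # ray.data.from_pandas() on the materialized pandas frame, per the diff above.
    return connector.to_ray_dataset()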
snowflake/ml/dataset/dataset_reader.py
@@ -5,6 +5,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
+from snowflake.snowpark._internal import utils as snowpark_utils

 _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
@@ -94,7 +95,10 @@ class DatasetReader(data_connector.DataConnector):
         dfs: List[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
-            df = self._session.read.option("pattern", file_path_pattern).parquet(source.url)
+            stage_reader = self._session.read.option("pattern", file_path_pattern)
+            if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
+                stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
+            df = stage_reader.parquet(source.url)
             if only_feature_cols and source.exclude_cols:
                 df = df.drop(source.exclude_cols)
             dfs.append(df)
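
The guard above feature-detects whether the installed Snowpark client accepts INFER_SCHEMA_OPTIONS as a non-format read option before capping schema inference at one file. A standalone sketch of the same pattern, using only the identifiers from the diff plus the public Session/DataFrame types:

# Sketch only: mirror of the reader construction added above, in isolation.
from snowflake.snowpark import DataFrame, Session
from snowflake.snowpark._internal import utils as snowpark_utils


def read_parquet_stage(session: Session, url: str, pattern: str) -> DataFrame:
    reader = session.read.option("pattern", pattern)
    if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
        # Only sample a single file when inferring the parquet schema.
        reader = reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
    return reader.parquet(url)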
snowflake/ml/jobs/_utils/constants.py
@@ -4,6 +4,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 # SPCS specification constants
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
 STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -18,10 +19,6 @@ DEFAULT_ENTRYPOINT_PATH = "func.py"
 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3

-# Multi Node Headless prototype constants
-# TODO: Replace this placeholder with the actual container runtime image tag.
-MULTINODE_HEADLESS_IMAGE_TAG = "latest"
-
 # Ray port configuration
 RAY_PORTS = {
     "HEAD_CLIENT_SERVER_PORT": "10001",
@@ -48,6 +45,7 @@ JOB_POLL_MAX_DELAY_SECONDS = 1

 # Magic attributes
 IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
+RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"

 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
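
The two result-path constants added above suggest that job results are pickled to a file whose location is passed via an environment variable. A purely hypothetical illustration of how a launcher might resolve that path (the actual wiring lives in mljob_launcher.py and payload_utils.py, which are not reproduced here):

# Hypothetical illustration only: resolving the ML job result path from the environment.
import os

RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"


def resolve_result_path() -> str:
    # Fall back to the default relative path when the job spec does not set the variable.
    return os.environ.get(RESULT_PATH_ENV_VAR, RESULT_PATH_DEFAULT_VALUE)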