snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +24 -0
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +73 -31
- snowflake/ml/jobs/decorators.py +3 -0
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +18 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -23,6 +23,15 @@ logger = logging.getLogger(__name__)
 _REST_COMPLETE_URL = "/api/v2/cortex/inference:complete"
 
 
+class ResponseFormat(TypedDict):
+    """Represents an object describing response format config for structured-output mode"""
+
+    type: str
+    """The response format type (e.g. "json")"""
+    schema: Dict[str, Any]
+    """The schema defining the structure of the response. For json it should be a valid json schema object"""
+
+
 class ConversationMessage(TypedDict):
     """Represents an conversation interaction."""
 
@@ -53,6 +62,9 @@ class CompleteOptions(TypedDict):
     """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
     from the language model. """
 
+    response_format: NotRequired[ResponseFormat]
+    """ An object describing response format config for structured-output mode """
+
 
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
@@ -108,6 +120,32 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers
 
 
+def _validate_response_format_object(options: CompleteOptions) -> None:
+    """Validate the response format object for structured-output mode.
+
+    More details can be found in:
+    docs.snowflake.com/en/user-guide/snowflake-cortex/complete-structured-outputs#using-complete-structured-outputs
+
+    Args:
+        options: The complete options object.
+
+    Raises:
+        ValueError: If the response format object is invalid or missing required fields.
+    """
+    if options is not None and options.get("response_format") is not None:
+        options_obj = options.get("response_format")
+        if not isinstance(options_obj, dict):
+            raise ValueError("'response_format' should be an object")
+        if options_obj.get("type") is None:
+            raise ValueError("'type' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("type"), str):
+            raise ValueError("'type' needs to be a str for 'response_format' object")
+        if options_obj.get("schema") is None:
+            raise ValueError("'schema' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("schema"), dict):
+            raise ValueError("'schema' needs to be a dict for 'response_format' object")
+
+
 def _make_request_body(
     model: str,
     prompt: Union[str, List[ConversationMessage]],
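For illustration, an options payload that passes the new `_validate_response_format_object` check pairs a `type` string with a JSON-schema dict; anything else raises `ValueError`. A minimal sketch (the schema contents below are hypothetical, not shipped with the package):

# Sketch: a structured-output options payload accepted by the new validator.
# The schema here is a hypothetical example.
options = {
    "response_format": {
        "type": "json",  # must be a str, or ValueError is raised
        "schema": {      # must be a dict (a JSON schema object for type "json")
            "type": "object",
            "properties": {"sentiment": {"type": "string"}},
            "required": ["sentiment"],
        },
    }
}

# Omitting "schema" (or passing a non-dict) fails validation:
# _validate_response_format_object({"response_format": {"type": "json"}})
# -> ValueError: 'schema' cannot be empty for 'response_format' object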
@@ -136,12 +174,16 @@ def _make_request_body(
             "response_when_unsafe": "Response filtered by Cortex Guard",
         }
         data["guardrails"] = guardrails_options
+    if "response_format" in options:
+        data["response_format"] = options["response_format"]
+
     return data
 
 
 # XP endpoint returns a dict response which needs to be converted to a format which can
 # be consumed by the SSEClient. This method does that.
 def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
+
     response = requests.Response()
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]
@@ -159,7 +201,6 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
             data = json.loads(data)
         except json.JSONDecodeError:
             raise ValueError(f"Request failed (request id: {request_id})")
-
     if response.status_code < 200 or response.status_code >= 300:
         if "message" not in data:
             raise ValueError(f"Request failed (request id: {request_id})")
@@ -241,11 +282,21 @@ def _return_stream_response(response: requests.Response, deadline: Optional[float]
         if deadline is not None and time.time() > deadline:
             raise TimeoutError()
         try:
-            yield json.loads(event.data)["choices"][0]["delta"]["content"]
+            parsed_resp = json.loads(event.data)
+        except json.JSONDecodeError:
+            raise ResponseParseException("Server response cannot be parsed")
+        try:
+            yield parsed_resp["choices"][0]["delta"]["content"]
         except (json.JSONDecodeError, KeyError, IndexError):
             # For the sake of evolution of the output format,
             # ignore stream messages that don't match the expected format.
-            pass
+
+            # This is the case of midstream errors which were introduced specifically for structured output.
+            # TODO: discuss during code review
+            if parsed_resp.get("error"):
+                yield json.dumps(parsed_resp)
+            else:
+                pass
 
 
 def _complete_call_sql_function_snowpark(
@@ -291,6 +342,8 @@ def _complete_non_streaming_impl(
         raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
     if isinstance(options, snowpark.Column):
         raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
+    if options and not isinstance(options, snowpark.Column):
+        _validate_response_format_object(options)
     return _complete_non_streaming_immediate(
         snow_api_xp_request_handler=snow_api_xp_request_handler,
         model=model,
@@ -309,6 +362,8 @@ def _complete_rest(
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
 ) -> Iterator[str]:
+    if options:
+        _validate_response_format_object(options)
    if snow_api_xp_request_handler is not None:
        response = _call_complete_xp(
            snow_api_xp_request_handler=snow_api_xp_request_handler,
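Taken together, these hunks let callers request structured output through the public `Complete` entry point, with the options dict validated before the request body is built. A usage sketch (model name and schema are illustrative placeholders, and an active Snowpark session is assumed):

from snowflake.cortex import Complete

# Illustrative only: model name and schema are placeholders, not package defaults.
response = Complete(
    model="mistral-large2",
    prompt="Classify the sentiment of this review: 'Great product!'",
    options={
        "response_format": {
            "type": "json",
            "schema": {"type": "object", "properties": {"label": {"type": "string"}}},
        }
    },
)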
snowflake/ml/_internal/file_utils.py
CHANGED
@@ -23,6 +23,7 @@ from typing import (
     Tuple,
     Union,
 )
+from urllib import parse
 
 import cloudpickle
 
@@ -294,7 +295,7 @@ def _retry_on_sql_error(exception: Exception) -> bool:
 def upload_directory_to_stage(
     session: snowpark.Session,
     local_path: pathlib.Path,
-    stage_path: pathlib.PurePosixPath,
+    stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
     *,
     statement_params: Optional[Dict[str, Any]] = None,
 ) -> None:
@@ -314,9 +315,22 @@ def upload_directory_to_stage(
         root_path = pathlib.Path(root)
         for filename in filenames:
             local_file_path = root_path / filename
-
-
-            )
+            relative_path = pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix())
+
+            if isinstance(stage_path, parse.ParseResult):
+                relative_stage_path = (pathlib.PosixPath(stage_path.path) / relative_path).parent
+                new_url = parse.ParseResult(
+                    scheme=stage_path.scheme,
+                    netloc=stage_path.netloc,
+                    path=str(relative_stage_path),
+                    params=stage_path.params,
+                    query=stage_path.query,
+                    fragment=stage_path.fragment,
+                )
+                stage_dir_path = parse.urlunparse(new_url)
+            else:
+                stage_dir_path = str((stage_path / relative_path).parent)
+
             retrying.retry(
                 retry_on_exception=_retry_on_sql_error,
                 stop_max_attempt_number=5,
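The widened `stage_path` type means callers can now pass a pre-parsed stage URL instead of a bare POSIX path. A sketch of both call shapes (the session object and all stage locations are placeholders):

import pathlib
from urllib import parse

from snowflake.ml._internal import file_utils

local_dir = pathlib.Path("/tmp/model_files")  # placeholder directory

# 1) Plain stage path, as before:
file_utils.upload_directory_to_stage(
    session, local_dir, pathlib.PurePosixPath("@MY_DB.MY_SCHEMA.MY_STAGE/models")
)

# 2) Pre-parsed URL: scheme and netloc are preserved, and only the path
#    component is extended with each uploaded file's relative directory.
file_utils.upload_directory_to_stage(
    session, local_dir, parse.urlparse("snow://model/MY_MODEL/versions/V1")
)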
snowflake/ml/_internal/platform_capabilities.py
CHANGED
@@ -37,6 +37,9 @@ class PlatformCapabilities:
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
 
+    def is_live_commit_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
         try:
snowflake/ml/_internal/telemetry.py
CHANGED
@@ -353,6 +353,10 @@ def get_function_usage_statement_params(
         statement_params[TelemetryField.KEY_API_CALLS.value].append({TelemetryField.NAME.value: api_call})
     if custom_tags:
         statement_params[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
+    # Snowpark doesn't support None value in statement_params from version 1.29
+    for k in statement_params:
+        if statement_params[k] is None:
+            statement_params[k] = ""
     return statement_params
 
 
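The loop above normalizes `None` values before the dict reaches Snowpark. In isolation the behavior looks like this (a standalone sketch, not package code; key names are illustrative):

statement_params = {"project": "MLOps", "subproject": None, "function_category": "USAGE"}

for k in statement_params:
    if statement_params[k] is None:
        statement_params[k] = ""  # Snowpark >= 1.29 rejects None values

assert statement_params["subproject"] == ""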
snowflake/ml/fileset/fileset.py
CHANGED
@@ -257,7 +257,6 @@ class FileSet:
                 function_name=telemetry.get_statement_params_full_func_name(
                     inspect.currentframe(), cls.__class__.__name__
                 ),
-                api_calls=[snowpark.DataFrameWriter.copy_into_location],
             ),
         )
     except snowpark_exceptions.SnowparkSQLException as e:
snowflake/ml/jobs/_utils/constants.py
CHANGED
@@ -18,6 +18,30 @@ DEFAULT_ENTRYPOINT_PATH = "func.py"
 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3
 
+# Multi Node Headless prototype constants
+# TODO: Replace this placeholder with the actual container runtime image tag.
+MULTINODE_HEADLESS_IMAGE_TAG = "latest"
+
+# Ray port configuration
+RAY_PORTS = {
+    "HEAD_CLIENT_SERVER_PORT": "10001",
+    "HEAD_GCS_PORT": "12001",
+    "HEAD_DASHBOARD_GRPC_PORT": "12002",
+    "HEAD_DASHBOARD_PORT": "12003",
+    "OBJECT_MANAGER_PORT": "12011",
+    "NODE_MANAGER_PORT": "12012",
+    "RUNTIME_ENV_AGENT_PORT": "12013",
+    "DASHBOARD_AGENT_GRPC_PORT": "12014",
+    "DASHBOARD_AGENT_LISTEN_PORT": "12015",
+    "MIN_WORKER_PORT": "12031",
+    "MAX_WORKER_PORT": "13000",
+}
+
+# Node health check configuration
+# TODO(SNOW-1937020): Revisit the health check configuration
+ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ENABLE_HEALTH_CHECKS = "false"
+
 # Job status polling constants
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 1
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -73,8 +73,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         ##### Ray configuration #####
         shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
+        # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
+        HELPER_EXISTS=$(
+            python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
+        )
+
         # Configure IP address and logging directory
-        eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+        if [ "$HELPER_EXISTS" = "true" ]; then
+            eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+        else
+            eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+        fi
         log_dir="/tmp/ray"
 
         # Check if eth0Ip is a valid IP address and fall back to default if necessary
@@ -82,6 +91,38 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
             eth0Ip="127.0.0.1"
         fi
 
+        # Get the environment values of SNOWFLAKE_JOBS_COUNT and SNOWFLAKE_JOB_INDEX for batch jobs
+        # These variables don't exist for non-batch jobs, so set defaults
+        if [ -z "$SNOWFLAKE_JOBS_COUNT" ]; then
+            SNOWFLAKE_JOBS_COUNT=1
+        fi
+
+        if [ -z "$SNOWFLAKE_JOB_INDEX" ]; then
+            SNOWFLAKE_JOB_INDEX=0
+        fi
+
+        # Determine if it should be a worker or a head node for batch jobs
+        if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
+            head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
+            if [ $? -eq 0 ]; then
+                # Parse the output using read
+                read head_index head_ip <<< "$head_info"
+
+                # Use the parsed variables
+                echo "Head Instance Index: $head_index"
+                echo "Head Instance IP: $head_ip"
+
+            else
+                echo "Error: Failed to get head instance information."
+                echo "$head_info"  # Print the error message
+                exit 1
+            fi
+
+            if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+                NODE_TYPE="worker"
+            fi
+        fi
+
         # Common parameters for both head and worker nodes
         common_params=(
             "--node-ip-address=$eth0Ip"
@@ -97,29 +138,62 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
             "--disable-usage-stats"
         )
 
-
-
-            "
-
-
-
-
-
-
-
+        if [ "$NODE_TYPE" = "worker" ]; then
+            # Use head_ip as head address if it exists
+            if [ ! -z "$head_ip" ]; then
+                RAY_HEAD_ADDRESS="$head_ip"
+            fi
+
+            # If RAY_HEAD_ADDRESS is still empty, exit with an error
+            if [ -z "$RAY_HEAD_ADDRESS" ]; then
+                echo "Error: Failed to determine head node address using default instance-index=0"
+                exit 1
+            fi
+
+            if [ -z "$SERVICE_NAME" ]; then
+                SERVICE_NAME="$SNOWFLAKE_SERVICE_NAME"
+            fi
+
+            if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
+                echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
+                exit 1
+            fi
+
+            # Additional worker-specific parameters
+            worker_params=(
+                "--address=${{RAY_HEAD_ADDRESS}}:12001"  # Connect to head node
+                "--resources={{\\"${{SERVICE_NAME}}\\":1, \\"node_tag:worker\\":1}}"  # Tag for node identification
+                "--object-store-memory=${{shm_size}}"
+            )
 
-
-
+            # Start Ray on a worker node
+            ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
+        else
+
+            # Additional head-specific parameters
+            head_params=(
+                "--head"
+                "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
+                "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Rort for Ray Client Server
+                "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
+                "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc
+                "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for debugging
+                "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
+            )
+
+            # Start Ray on the head node
+            ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+            ##### End Ray configuration #####
 
-
+            # TODO: Monitor MLRS and handle process crashes
+            python -m web.ml_runtime_grpc_server &
 
-
+            # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
 
-
-
-
+            # Run user's Python entrypoint
+            echo Running command: python "$@"
+            python "$@"
+        fi
     """
 ).strip()
 
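The head/worker election in the startup script reduces to: the node whose job index matches the head index reported by the helper becomes the Ray head, and every other node becomes a worker. A Python paraphrase of that shell logic (the helper's "<head_index> <head_ip>" output format is taken from the script; everything else is illustrative):

import os

def elect_node_type(head_info: str) -> str:
    """Paraphrase of the startup script's election logic, for illustration only."""
    jobs_count = int(os.environ.get("SNOWFLAKE_JOBS_COUNT", "1"))
    job_index = int(os.environ.get("SNOWFLAKE_JOB_INDEX", "0"))
    if jobs_count <= 1:
        return "head"  # single-node jobs never start workers
    head_index, _head_ip = head_info.split()
    return "worker" if job_index != int(head_index) else "head"

# e.g. with SNOWFLAKE_JOBS_COUNT=3, SNOWFLAKE_JOB_INDEX=2 and helper output "0 10.0.0.5":
# elect_node_type("0 10.0.0.5") -> "worker"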
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -26,19 +26,22 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
     )
 
 
-def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
+def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = constants.DEFAULT_IMAGE_TAG
 
     # Try to pull latest image tag from server side if possible
-
-
-
+    if not image_tag:
+        query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
+        if query_result:
+            image_tag = query_result[0]["value"]
+
+    if image_tag is None:
+        image_tag = constants.DEFAULT_IMAGE_TAG
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -93,6 +96,7 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[List[str]] = None,
+    num_instances: Optional[int] = None,
 ) -> Dict[str, Any]:
     """
     Generate a service specification for a job.
@@ -102,12 +106,21 @@ def generate_service_spec(
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
+        num_instances: Number of instances for multi-node job
 
     Returns:
         Job service specification
     """
+    is_multi_node = num_instances is not None and num_instances > 1
+
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    image_spec = _get_image_spec(session, compute_pool)
+    if is_multi_node:
+        # If the job is of multi-node, we will need a different image which contains
+        # module snowflake.runtime.utils.get_instance_ip
+        # TODO(SNOW-1961849): Remove the hard-coded image name
+        image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
+    else:
+        image_spec = _get_image_spec(session, compute_pool)
     resource_requests: Dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -176,31 +189,53 @@ def generate_service_spec(
 
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
+    endpoints = []
+
+    if is_multi_node:
+        # Update environment variables for multi-node job
+        env_vars.update(constants.RAY_PORTS)
+        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+
+        # Define Ray endpoints for intra-service instance communication
+        ray_endpoints = [
+            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
+            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
+            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+            {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
+            {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
+            {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
+            {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
+            {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
+            {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
+        ]
+        endpoints.extend(ray_endpoints)
+
+    spec_dict = {
+        "containers": [
+            {
+                "name": constants.DEFAULT_CONTAINER_NAME,
+                "image": image_spec.full_name,
+                "command": ["/usr/local/bin/_entrypoint.sh"],
+                "args": [
+                    (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
+                ]
+                + (args or []),
+                "env": env_vars,
+                "volumeMounts": volume_mounts,
+                "resources": {
+                    "requests": resource_requests,
+                    "limits": resource_limits,
                 },
-
-
-
+            },
+        ],
+        "volumes": volumes,
     }
+    if endpoints:
+        spec_dict["endpoints"] = endpoints
+
+    # Assemble into service specification dict
+    spec = {"spec": spec_dict}
 
 
     return spec
@@ -248,7 +283,10 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
 
 
 def _merge_lists_of_dicts(
-    base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
+    base: List[Dict[str, Any]],
+    patch: List[Dict[str, Any]],
+    merge_key: str = "name",
+    display_name: str = "",
 ) -> List[Dict[str, Any]]:
     """
     Attempts to merge lists of dicts by matching on a merge key (default "name").
@@ -288,7 +326,11 @@ def _merge_lists_of_dicts(
 
         # Apply patch
         if key in result:
-            d = merge_patch(result[key], d, display_name=f"{display_name}[{merge_key}={d[merge_key]}]")
+            d = merge_patch(
+                result[key],
+                d,
+                display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
+            )
             # TODO: Should we drop the item if the patch result is empty save for the merge key?
             # Can check `d.keys() <= {merge_key}`
             result[key] = d
snowflake/ml/jobs/decorators.py
CHANGED
@@ -25,6 +25,7 @@ def remote(
     query_warehouse: Optional[str] = None,
     env_vars: Optional[Dict[str, str]] = None,
     session: Optional[snowpark.Session] = None,
+    num_instances: Optional[int] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
     """
     Submit a job to the compute pool.
@@ -37,6 +38,7 @@ def remote(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
         session: The Snowpark session to use. If none specified, uses active session.
+        num_instances: The number of nodes in the job. If none specified, create a single node job.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -62,6 +64,7 @@ def remote(
             query_warehouse=query_warehouse,
             env_vars=env_vars,
             session=session,
+            num_instances=num_instances,
        )
        assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
        return job
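With `num_instances` plumbed through, a decorated function can request a multi-node job directly. A sketch, assuming the decorator's leading arguments are the compute pool and payload stage (both names are placeholders):

from snowflake.ml.jobs import remote

# "MY_POOL" and "payload_stage" are placeholders; num_instances > 1 requests a multi-node job.
@remote("MY_POOL", stage_name="payload_stage", num_instances=2)
def train(data_path: str) -> None:
    print(f"training on {data_path}")

job = train("@MY_STAGE/data")  # returns an MLJob handle instead of running locally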
snowflake/ml/jobs/manager.py
CHANGED
@@ -213,6 +213,7 @@ def _submit_job(
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[Dict[str, Any]] = None,
     session: Optional[snowpark.Session] = None,
+    num_instances: Optional[int] = None,
 ) -> jb.MLJob:
     """
     Submit a job to the compute pool.
@@ -229,6 +230,7 @@ def _submit_job(
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
         session: The Snowpark session to use. If none specified, uses active session.
+        num_instances: The number of instances to use for the job. If none specified, single node job is created.
 
     Returns:
         An object representing the submitted job.
@@ -254,6 +256,7 @@ def _submit_job(
         compute_pool=compute_pool,
         payload=uploaded_payload,
         args=args,
+        num_instances=num_instances,
     )
     spec_overrides = spec_utils.generate_spec_overrides(
         environment_vars=env_vars,
@@ -281,6 +284,8 @@ def _submit_job(
     query_warehouse = query_warehouse or session.get_current_warehouse()
     if query_warehouse:
         query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+    if num_instances:
+        query.append(f"REPLICAS = {num_instances}")
 
     # Submit job
     query_text = "\n".join(line for line in query if line)
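The new branch appends a REPLICAS clause to the job-service DDL when multiple instances are requested. A simplified sketch of the clause assembly (only the two appends shown in the hunk above are taken from the source; the leading statement lines are abridged placeholders):

# Abridged illustration of the query assembly in _submit_job.
query = ["EXECUTE JOB SERVICE", "IN COMPUTE POOL MY_POOL"]  # placeholders
query_warehouse = "MY_WH"  # placeholder
num_instances = 2

if query_warehouse:
    query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
if num_instances:
    query.append(f"REPLICAS = {num_instances}")

query_text = "\n".join(line for line in query if line)
print(query_text)
# EXECUTE JOB SERVICE
# IN COMPUTE POOL MY_POOL
# QUERY_WAREHOUSE = MY_WH
# REPLICAS = 2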
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -746,7 +746,7 @@ class ModelVersion(lineage_node.LineageNode):
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
-        gpu_requests: Optional[str] = None,
+        gpu_requests: Optional[Union[str, int]] = None,
         num_workers: Optional[int] = None,
         max_batch_rows: Optional[int] = None,
         force_rebuild: bool = False,
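Since `gpu_requests` now accepts an int as well as a string, both spellings below are valid. A sketch of the call site (service, pool, and repo names are placeholders; `mv` is a ModelVersion handle):

# Placeholders throughout; mv is a snowflake.ml ModelVersion handle.
mv.create_service(
    service_name="MY_SERVICE",
    service_compute_pool="MY_GPU_POOL",
    image_repo="MY_DB.MY_SCHEMA.MY_REPO",
    gpu_requests=1,  # now accepted as an int...
    max_instances=1,
)
# ...equivalent to the previous string-only form:
# mv.create_service(..., gpu_requests="1")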
|