snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that public registry.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
The remaining hunks in this diff cover `snowflake/ml/jobs/_utils/constants.py`, `payload_utils.py`, and `spec_utils.py`. Removed lines whose content was truncated in the rendered diff are kept as bare or partial `-` lines.

--- a/snowflake/ml/jobs/_utils/constants.py
+++ b/snowflake/ml/jobs/_utils/constants.py
@@ -4,21 +4,51 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 # SPCS specification constants
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+MEMORY_VOLUME_NAME = "dshm"
+STAGE_VOLUME_NAME = "stage-volume"
+STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "0.
+DEFAULT_IMAGE_TAG = "0.9.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3
 
+# Multi Node Headless prototype constants
+# TODO: Replace this placeholder with the actual container runtime image tag.
+MULTINODE_HEADLESS_IMAGE_TAG = "latest"
+
+# Ray port configuration
+RAY_PORTS = {
+    "HEAD_CLIENT_SERVER_PORT": "10001",
+    "HEAD_GCS_PORT": "12001",
+    "HEAD_DASHBOARD_GRPC_PORT": "12002",
+    "HEAD_DASHBOARD_PORT": "12003",
+    "OBJECT_MANAGER_PORT": "12011",
+    "NODE_MANAGER_PORT": "12012",
+    "RUNTIME_ENV_AGENT_PORT": "12013",
+    "DASHBOARD_AGENT_GRPC_PORT": "12014",
+    "DASHBOARD_AGENT_LISTEN_PORT": "12015",
+    "MIN_WORKER_PORT": "12031",
+    "MAX_WORKER_PORT": "13000",
+}
+
+# Node health check configuration
+# TODO(SNOW-1937020): Revisit the health check configuration
+ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ENABLE_HEALTH_CHECKS = "false"
+
 # Job status polling constants
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 1
 
+# Magic attributes
+IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
+
 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
 # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
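The multi-node constants above feed the updated service-spec generation in `spec_utils.py` (its hunks appear further below). The following sketch is illustrative only, not package code: it shows how the Ray port map and the payload-directory variable could be combined into container environment variables, mirroring what `generate_service_spec` does in 1.8.0. It assumes snowflake-ml-python 1.8.0 is installed; `build_env_vars` is a hypothetical helper.

```python
from typing import Dict

from snowflake.ml.jobs._utils import constants  # internal module, present in 1.8.0


def build_env_vars(payload_mount: str, is_multi_node: bool) -> Dict[str, str]:
    # Every job container learns where its payload is mounted.
    env_vars = {constants.PAYLOAD_DIR_ENV_VAR: payload_mount}
    if is_multi_node:
        # Multi-node jobs additionally receive the Ray port assignments
        # and the health-check toggle on every instance.
        env_vars.update(constants.RAY_PORTS)
        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
    return env_vars


print(build_env_vars(constants.STAGE_VOLUME_MOUNT_PATH, is_multi_node=True))
```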
--- a/snowflake/ml/jobs/_utils/payload_utils.py
+++ b/snowflake/ml/jobs/_utils/payload_utils.py
@@ -1,5 +1,8 @@
+import functools
 import inspect
 import io
+import itertools
+import pickle
 import sys
 import textwrap
 from pathlib import Path, PurePath
@@ -19,9 +22,11 @@ import cloudpickle as cp
 
 from snowflake import snowpark
 from snowflake.ml.jobs._utils import constants, types
+from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
 _SUPPORTED_ARG_TYPES = {str, int, float}
+_SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
 _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
@@ -68,16 +73,56 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     ##### Ray configuration #####
     shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
+    # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
+    HELPER_EXISTS=$(
+        python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
+    )
+
     # Configure IP address and logging directory
-
+    if [ "$HELPER_EXISTS" = "true" ]; then
+        eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+    else
+        eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+    fi
     log_dir="/tmp/ray"
 
-# Check if eth0Ip is
-if [
-# This should never happen, but just in case ethOIp is not set, we should default to localhost
+    # Check if eth0Ip is a valid IP address and fall back to default if necessary
+    if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
         eth0Ip="127.0.0.1"
     fi
 
+    # Get the environment values of SNOWFLAKE_JOBS_COUNT and SNOWFLAKE_JOB_INDEX for batch jobs
+    # These variables don't exist for non-batch jobs, so set defaults
+    if [ -z "$SNOWFLAKE_JOBS_COUNT" ]; then
+        SNOWFLAKE_JOBS_COUNT=1
+    fi
+
+    if [ -z "$SNOWFLAKE_JOB_INDEX" ]; then
+        SNOWFLAKE_JOB_INDEX=0
+    fi
+
+    # Determine if it should be a worker or a head node for batch jobs
+    if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
+        head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
+        if [ $? -eq 0 ]; then
+            # Parse the output using read
+            read head_index head_ip <<< "$head_info"
+
+            # Use the parsed variables
+            echo "Head Instance Index: $head_index"
+            echo "Head Instance IP: $head_ip"
+
+        else
+            echo "Error: Failed to get head instance information."
+            echo "$head_info" # Print the error message
+            exit 1
+        fi
+
+        if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+            NODE_TYPE="worker"
+        fi
+    fi
+
     # Common parameters for both head and worker nodes
     common_params=(
         "--node-ip-address=$eth0Ip"
@@ -93,33 +138,94 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         "--disable-usage-stats"
     )
 
-
-
-"
-
-
-
-
-
-
-
+    if [ "$NODE_TYPE" = "worker" ]; then
+        # Use head_ip as head address if it exists
+        if [ ! -z "$head_ip" ]; then
+            RAY_HEAD_ADDRESS="$head_ip"
+        fi
+
+        # If RAY_HEAD_ADDRESS is still empty, exit with an error
+        if [ -z "$RAY_HEAD_ADDRESS" ]; then
+            echo "Error: Failed to determine head node address using default instance-index=0"
+            exit 1
+        fi
+
+        if [ -z "$SERVICE_NAME" ]; then
+            SERVICE_NAME="$SNOWFLAKE_SERVICE_NAME"
+        fi
+
+        if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
+            echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
+            exit 1
+        fi
+
+        # Additional worker-specific parameters
+        worker_params=(
+            "--address=${{RAY_HEAD_ADDRESS}}:12001" # Connect to head node
+            "--resources={{\\"${{SERVICE_NAME}}\\":1, \\"node_tag:worker\\":1}}" # Tag for node identification
+            "--object-store-memory=${{shm_size}}"
+        )
 
-
-
-
+        # Start Ray on a worker node
+        ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
+    else
+
+        # Additional head-specific parameters
+        head_params=(
+            "--head"
+            "--port=${{RAY_HEAD_GCS_PORT:-12001}}" # Port of Ray (GCS server)
+            "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}" # Rort for Ray Client Server
+            "--dashboard-host=${{NODE_IP_ADDRESS}}" # Host to bind the dashboard server
+            "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}" # Dashboard head to listen for grpc
+            "--dashboard-port=${{DASHBOARD_PORT}}" # Port to bind the dashboard server for debugging
+            "--resources={{\\"node_tag:head\\":1}}" # Resource tag for selecting head as coordinator
+        )
+
+        # Start Ray on the head node
+        ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+        ##### End Ray configuration #####
 
-
-
+        # TODO: Monitor MLRS and handle process crashes
+        python -m web.ml_runtime_grpc_server &
 
-
+        # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
 
-
-
-
+        # Run user's Python entrypoint
+        echo Running command: python "$@"
+        python "$@"
+    fi
     """
 ).strip()
 
 
+def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
+    parent = parent.absolute()
+    if entrypoint is None:
+        if parent.is_file():
+            # Infer entrypoint from source
+            entrypoint = parent
+        else:
+            raise ValueError("entrypoint must be provided when source is a directory")
+    elif entrypoint.is_absolute():
+        # Absolute path - validate it's a subpath of source dir
+        if not entrypoint.is_relative_to(parent):
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+    else:
+        # Relative path
+        if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
+            # Relative to working dir iff path is relative to source dir and exists
+            entrypoint = abs_entrypoint
+        else:
+            # Relative to source dir
+            entrypoint = parent.joinpath(entrypoint)
+    if not entrypoint.is_file():
+        raise FileNotFoundError(
+            "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
+            f" the source directory (source={parent}, entrypoint={entrypoint})"
+        )
+    return entrypoint
+
+
 class JobPayload:
     def __init__(
         self,
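The new `_resolve_entrypoint` helper pins down how a job's entrypoint is located: it may be inferred when the source is a single file, and otherwise must resolve to an existing `.py` file under the source directory. A small illustration of those rules using plain `pathlib` on a temporary directory (hypothetical paths, not package code); the remaining `payload_utils.py` hunks continue below.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp) / "job_src"
    (source / "scripts").mkdir(parents=True)
    entrypoint = source / "scripts" / "train.py"
    entrypoint.write_text("print('hello')\n")

    # A relative entrypoint is resolved against the source directory ...
    assert (source / "scripts/train.py").is_file()

    # ... and an absolute entrypoint must be a subpath of the source directory.
    assert entrypoint.is_relative_to(source)

    # Paths outside the source tree (or non-.py files) are rejected.
    assert not Path("/etc/passwd").is_relative_to(source)
```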
@@ -138,23 +244,23 @@ class JobPayload
             # since we will generate the file from the serialized callable
             pass
         elif isinstance(self.source, Path):
-# Validate
-
-
-
-
-
-
-
-if not
-
-
-
-
-
-
-
+            # Validate source
+            source = self.source
+            if not source.exists():
+                raise FileNotFoundError(f"{source} does not exist")
+            source = source.absolute()
+
+            # Validate entrypoint
+            entrypoint = _resolve_entrypoint(source, self.entrypoint)
+            if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+                raise ValueError(
+                    "Unsupported entrypoint type:"
+                    f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+                )
+
+            # Update fields with normalized values
+            self.source = source
+            self.entrypoint = entrypoint
         else:
             raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
 
@@ -168,12 +274,16 @@ class JobPayload
         entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)
 
         # Create stage if necessary
-stage_name = stage_path.parts[0]
-
-
-"
-
-
+        stage_name = stage_path.parts[0].lstrip("@")
+        # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
+        try:
+            session.sql(f"describe stage {stage_name}").collect()
+        except sp_exceptions.SnowparkSQLException:
+            session.sql(
+                f"create stage if not exists {stage_name}"
+                " encryption = ( type = 'SNOWFLAKE_SSE' )"
+                " comment = 'Created by snowflake.ml.jobs Python API'"
+            ).collect()
 
         # Upload payload to stage
         if not isinstance(source, Path):
@@ -237,7 +347,7 @@ class JobPayload
         )
 
 
-def
+def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     # Unwrap Optional type annotations
     param_type = param.annotation
     if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -249,7 +359,7 @@ def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     return cast(Type[object], param_type)
 
 
-def
+def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
     # Validate param_type is a supported type
     if param_type not in _SUPPORTED_ARG_TYPES:
         raise ValueError(
@@ -258,41 +368,60 @@ def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
         )
 
 
-def
-
-if any(
-p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-for p in signature.parameters.values()
-):
-raise NotImplementedError("Function must not have unpacking arguments (* or **)")
-
-# Mirrored from Snowpark generate_python_code() function
-# https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+def _generate_source_code_comment(func: Callable[..., Any]) -> str:
+    """Generate a comment string containing the source code of a function for readability."""
     try:
-
-
-
+        if isinstance(func, functools.partial):
+            # Unwrap functools.partial and generate source code comment from the original function
+            comment = code_generation.generate_source_code(func.func) # type: ignore[arg-type]
+            args = itertools.chain((repr(a) for a in func.args), (f"{k}={v!r}" for k, v in func.keywords.items()))
+
+            # Update invocation comment to show arguments passed via functools.partial
+            comment = comment.replace(
+                f"= {func.func.__name__}",
+                "= functools.partial({}({}))".format(
+                    func.func.__name__,
+                    ", ".join(args),
+                ),
+            )
+            return comment
+        else:
+            return code_generation.generate_source_code(func) # type: ignore[arg-type]
     except Exception as exc:
         error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
-
-
-func_name = "func"
-func_code = f"""
-{source_code_comment}
+        return code_generation.comment_source_code(error_msg)
 
-import pickle
-{func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
-"""
 
+def _serialize_callable(func: Callable[..., Any]) -> bytes:
+    try:
+        func_bytes: bytes = cp.dumps(func)
+        return func_bytes
+    except pickle.PicklingError as e:
+        if isinstance(func, functools.partial):
+            # Try to find which part of the partial isn't serializable for better debuggability
+            objects = [
+                ("function", func.func),
+                *((f"positional arg {i}", a) for i, a in enumerate(func.args)),
+                *((f"keyword arg '{k}'", v) for k, v in func.keywords.items()),
+            ]
+            for name, obj in objects:
+                try:
+                    cp.dumps(obj)
+                except pickle.PicklingError:
+                    raise ValueError(f"Unable to serialize {name}: {obj}") from e
+        raise ValueError(f"Unable to serialize function: {func}") from e
+
+
+def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
     # Generate argparse logic for argument handling (type coercion, default values, etc)
     argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
     argparse_postproc = []
     for name, param in signature.parameters.items():
         opts = {}
 
-param_type =
+        param_type = _get_parameter_type(param)
         if param_type is not None:
-
+            _validate_parameter_type(param_type, name)
             opts["type"] = param_type.__name__
 
         if param.default != inspect.Parameter.empty:
@@ -324,6 +453,37 @@ import pickle
         )
     argparse_code.append("args = parser.parse_args()")
     param_code = "\n".join(argparse_code + argparse_postproc)
+    param_code += f"\n{output_name} = vars(args)"
+
+    return param_code
+
+
+def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+    """Generate an entrypoint script from a Python function."""
+    signature = inspect.signature(func)
+    if any(
+        p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+        for p in signature.parameters.values()
+    ):
+        raise NotImplementedError("Function must not have unpacking arguments (* or **)")
+
+    # Mirrored from Snowpark generate_python_code() function
+    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+    source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+
+    func_name = "func"
+    func_code = f"""
+{source_code_comment}
+
+import pickle
+{func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+"""
+
+    arg_dict_name = "kwargs"
+    if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+        param_code = f"{arg_dict_name} = {{}}"
+    else:
+        param_code = _generate_param_handler_code(signature, arg_dict_name)
 
     return f"""
 ### Version guard to check compatibility across Python versions ###
@@ -348,5 +508,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
 if __name__ == '__main__':
 {textwrap.indent(param_code, ' ')}
 
-{func_name}(**
+{func_name}(**{arg_dict_name})
 """
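Together, `_serialize_callable` and `generate_python_code` turn a Python callable into a self-contained entrypoint script: the callable is cloudpickled to hex on the client, and the generated script rebuilds it with `pickle.loads(bytes.fromhex(...))` before invoking it. A minimal sketch of that round trip (assumes `cloudpickle` is installed; `greet` is a made-up example):

```python
import pickle

import cloudpickle as cp


def greet(name: str) -> str:
    return f"hello {name}"


# What generate_python_code embeds into the entrypoint script:
payload_hex = cp.dumps(greet).hex()

# What the generated entrypoint executes at runtime:
restored = pickle.loads(bytes.fromhex(payload_hex))
assert restored("snowflake") == "hello snowflake"
```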
--- a/snowflake/ml/jobs/_utils/spec_utils.py
+++ b/snowflake/ml/jobs/_utils/spec_utils.py
@@ -26,19 +26,22 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     )
 
 
-def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
+def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-image_tag = constants.DEFAULT_IMAGE_TAG
 
     # Try to pull latest image tag from server side if possible
-
-
-
+    if not image_tag:
+        query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
+        if query_result:
+            image_tag = query_result[0]["value"]
+
+    if image_tag is None:
+        image_tag = constants.DEFAULT_IMAGE_TAG
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -93,6 +96,7 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[List[str]] = None,
+    num_instances: Optional[int] = None,
 ) -> Dict[str, Any]:
     """
     Generate a service specification for a job.
@@ -102,12 +106,21 @@
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
+        num_instances: Number of instances for multi-node job
 
     Returns:
         Job service specification
     """
+    is_multi_node = num_instances is not None and num_instances > 1
+
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-
+    if is_multi_node:
+        # If the job is of multi-node, we will need a different image which contains
+        # module snowflake.runtime.utils.get_instance_ip
+        # TODO(SNOW-1961849): Remove the hard-coded image name
+        image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
+    else:
+        image_spec = _get_image_spec(session, compute_pool)
     resource_requests: Dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -141,68 +154,88 @@
     )
 
     # Mount 30% of memory limit as a memory-backed volume
-memory_volume_name = "dshm"
     memory_volume_size = min(
         ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
         image_spec.resource_requests.memory,
     )
     volume_mounts.append(
         {
-"name":
+            "name": constants.MEMORY_VOLUME_NAME,
             "mountPath": "/dev/shm",
         }
     )
     volumes.append(
         {
-"name":
+            "name": constants.MEMORY_VOLUME_NAME,
             "source": "memory",
             "size": f"{memory_volume_size}Gi",
         }
     )
 
     # Mount payload as volume
-stage_mount = PurePath(
-stage_volume_name = "stage-volume"
+    stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
     volume_mounts.append(
         {
-"name":
+            "name": constants.STAGE_VOLUME_NAME,
             "mountPath": stage_mount.as_posix(),
         }
     )
     volumes.append(
         {
-"name":
+            "name": constants.STAGE_VOLUME_NAME,
             "source": payload.stage_path.as_posix(),
         }
     )
 
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
+    endpoints = []
+
+    if is_multi_node:
+        # Update environment variables for multi-node job
+        env_vars.update(constants.RAY_PORTS)
+        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+
+        # Define Ray endpoints for intra-service instance communication
+        ray_endpoints = [
+            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
+            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
+            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+            {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
+            {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
+            {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
+            {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
+            {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
+            {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
+        ]
+        endpoints.extend(ray_endpoints)
+
+    spec_dict = {
+        "containers": [
+            {
+                "name": constants.DEFAULT_CONTAINER_NAME,
+                "image": image_spec.full_name,
+                "command": ["/usr/local/bin/_entrypoint.sh"],
+                "args": [
+                    (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
+                ]
+                + (args or []),
+                "env": env_vars,
+                "volumeMounts": volume_mounts,
+                "resources": {
+                    "requests": resource_requests,
+                    "limits": resource_limits,
                 },
-
-
-
+            },
+        ],
+        "volumes": volumes,
     }
+    if endpoints:
+        spec_dict["endpoints"] = endpoints
+
+    # Assemble into service specification dict
+    spec = {"spec": spec_dict}
 
     return spec
 
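The multi-node switch above is driven entirely by `num_instances`: only when it is greater than 1 does the spec pick the headless image tag and publish the Ray endpoints. A hypothetical sketch of that decision (`pick_image_tag` is illustrative, not package code; the constants are real in 1.8.0); the remaining `spec_utils.py` hunks continue below.

```python
from typing import Optional

from snowflake.ml.jobs._utils import constants  # internal module, present in 1.8.0


def pick_image_tag(num_instances: Optional[int]) -> Optional[str]:
    is_multi_node = num_instances is not None and num_instances > 1
    # Multi-node jobs need the image that ships snowflake.runtime.utils.get_instance_ip.
    return constants.MULTINODE_HEADLESS_IMAGE_TAG if is_multi_node else None


print(pick_image_tag(3))     # placeholder "latest" tag for the headless image
print(pick_image_tag(None))  # None -> _get_image_spec falls back to the account/default tag
```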
@@ -250,7 +283,10 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
 
 
 def _merge_lists_of_dicts(
-base: List[Dict[str, Any]],
+    base: List[Dict[str, Any]],
+    patch: List[Dict[str, Any]],
+    merge_key: str = "name",
+    display_name: str = "",
 ) -> List[Dict[str, Any]]:
     """
     Attempts to merge lists of dicts by matching on a merge key (default "name").
@@ -290,7 +326,11 @@ def _merge_lists_of_dicts(
 
         # Apply patch
         if key in result:
-d = merge_patch(
+            d = merge_patch(
+                result[key],
+                d,
+                display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
+            )
         # TODO: Should we drop the item if the patch result is empty save for the merge key?
         # Can check `d.keys() <= {merge_key}`
         result[key] = d