snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +25 -1
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +95 -31
  9. snowflake/ml/jobs/decorators.py +7 -0
  10. snowflake/ml/jobs/manager.py +20 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +113 -17
  13. snowflake/ml/model/_client/ops/service_ops.py +16 -5
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +10 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  52. snowflake/ml/modeling/metrics/ranking.py +3 -0
  53. snowflake/ml/modeling/metrics/regression.py +3 -0
  54. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  55. snowflake/ml/registry/_manager/model_manager.py +55 -7
  56. snowflake/ml/registry/registry.py +59 -1
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0

snowflake/cortex/_complete.py

@@ -23,6 +23,15 @@ logger = logging.getLogger(__name__)
 _REST_COMPLETE_URL = "/api/v2/cortex/inference:complete"
 
 
+class ResponseFormat(TypedDict):
+    """Represents an object describing response format config for structured-output mode"""
+
+    type: str
+    """The response format type (e.g. "json")"""
+    schema: Dict[str, Any]
+    """The schema defining the structure of the response. For json it should be a valid json schema object"""
+
+
 class ConversationMessage(TypedDict):
     """Represents a conversation interaction."""
 
@@ -53,6 +62,9 @@ class CompleteOptions(TypedDict):
     """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
    from the language model. """
 
+    response_format: NotRequired[ResponseFormat]
+    """ An object describing response format config for structured-output mode """
+
 
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
@@ -108,6 +120,32 @@ def _make_common_request_headers() -> Dict[str, str]:
     return headers
 
 
+def _validate_response_format_object(options: CompleteOptions) -> None:
+    """Validate the response format object for structured-output mode.
+
+    More details can be found in:
+    docs.snowflake.com/en/user-guide/snowflake-cortex/complete-structured-outputs#using-complete-structured-outputs
+
+    Args:
+        options: The complete options object.
+
+    Raises:
+        ValueError: If the response format object is invalid or missing required fields.
+    """
+    if options is not None and options.get("response_format") is not None:
+        options_obj = options.get("response_format")
+        if not isinstance(options_obj, dict):
+            raise ValueError("'response_format' should be an object")
+        if options_obj.get("type") is None:
+            raise ValueError("'type' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("type"), str):
+            raise ValueError("'type' needs to be a str for 'response_format' object")
+        if options_obj.get("schema") is None:
+            raise ValueError("'schema' cannot be empty for 'response_format' object")
+        if not isinstance(options_obj.get("schema"), dict):
+            raise ValueError("'schema' needs to be a dict for 'response_format' object")
+
+
 def _make_request_body(
     model: str,
     prompt: Union[str, List[ConversationMessage]],
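
The checks translate to behavior like the following (a sketch exercising only the code in this hunk):

    # Raises ValueError: 'response_format' should be an object
    _validate_response_format_object({"response_format": "json"})

    # Raises ValueError: 'schema' cannot be empty for 'response_format' object
    _validate_response_format_object({"response_format": {"type": "json"}})

    # Passes silently: both required fields are present and well-typed
    _validate_response_format_object(
        {"response_format": {"type": "json", "schema": {"type": "object"}}}
    )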
@@ -136,12 +174,16 @@ def _make_request_body(
             "response_when_unsafe": "Response filtered by Cortex Guard",
         }
         data["guardrails"] = guardrails_options
+    if "response_format" in options:
+        data["response_format"] = options["response_format"]
+
     return data
 
 
 # XP endpoint returns a dict response which needs to be converted to a format which can
 # be consumed by the SSEClient. This method does that.
 def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
+
     response = requests.Response()
     response.status_code = int(raw_resp["status"])
     response.headers = raw_resp["headers"]
@@ -159,7 +201,6 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
         data = json.loads(data)
     except json.JSONDecodeError:
         raise ValueError(f"Request failed (request id: {request_id})")
-
     if response.status_code < 200 or response.status_code >= 300:
         if "message" not in data:
             raise ValueError(f"Request failed (request id: {request_id})")
@@ -241,11 +282,21 @@ def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
         if deadline is not None and time.time() > deadline:
             raise TimeoutError()
         try:
-            yield json.loads(event.data)["choices"][0]["delta"]["content"]
+            parsed_resp = json.loads(event.data)
+        except json.JSONDecodeError:
+            raise ResponseParseException("Server response cannot be parsed")
+        try:
+            yield parsed_resp["choices"][0]["delta"]["content"]
         except (json.JSONDecodeError, KeyError, IndexError):
             # For the sake of evolution of the output format,
             # ignore stream messages that don't match the expected format.
-            pass
+
+            # This is the case of midstream errors which were introduced specifically for structured output.
+            # TODO: discuss during code review
+            if parsed_resp.get("error"):
+                yield json.dumps(parsed_resp)
+            else:
+                pass
 
 
 def _complete_call_sql_function_snowpark(
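
Downstream consumers should therefore be prepared for a serialized error object arriving mid-stream. A defensive consumption sketch (hypothetical model name; assumes the public complete(..., stream=True) entry point yields the strings produced above):

    import json

    from snowflake.cortex import complete

    for chunk in complete(model="mistral-large2", prompt="List three colors as JSON.", stream=True):
        try:
            parsed = json.loads(chunk)
            if isinstance(parsed, dict) and parsed.get("error"):
                raise RuntimeError(f"Midstream error from Cortex: {parsed}")
        except json.JSONDecodeError:
            pass  # an ordinary text delta, not a serialized error object
        print(chunk, end="")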
@@ -291,6 +342,8 @@ def _complete_non_streaming_impl(
         raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
     if isinstance(options, snowpark.Column):
         raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
+    if options and not isinstance(options, snowpark.Column):
+        _validate_response_format_object(options)
     return _complete_non_streaming_immediate(
         snow_api_xp_request_handler=snow_api_xp_request_handler,
         model=model,
@@ -309,6 +362,8 @@ def _complete_rest(
     session: Optional[snowpark.Session] = None,
     deadline: Optional[float] = None,
 ) -> Iterator[str]:
+    if options:
+        _validate_response_format_object(options)
     if snow_api_xp_request_handler is not None:
         response = _call_complete_xp(
             snow_api_xp_request_handler=snow_api_xp_request_handler,

snowflake/ml/_internal/file_utils.py

@@ -23,6 +23,7 @@ from typing import (
     Tuple,
     Union,
 )
+from urllib import parse
 
 import cloudpickle
 
@@ -294,7 +295,7 @@ def _retry_on_sql_error(exception: Exception) -> bool:
 def upload_directory_to_stage(
     session: snowpark.Session,
     local_path: pathlib.Path,
-    stage_path: pathlib.PurePosixPath,
+    stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
     *,
     statement_params: Optional[Dict[str, Any]] = None,
 ) -> None:
@@ -314,9 +315,22 @@ def upload_directory_to_stage(
         root_path = pathlib.Path(root)
         for filename in filenames:
             local_file_path = root_path / filename
-            stage_dir_path = (
-                stage_path / pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix()).parent
-            )
+            relative_path = pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix())
+
+            if isinstance(stage_path, parse.ParseResult):
+                relative_stage_path = (pathlib.PosixPath(stage_path.path) / relative_path).parent
+                new_url = parse.ParseResult(
+                    scheme=stage_path.scheme,
+                    netloc=stage_path.netloc,
+                    path=str(relative_stage_path),
+                    params=stage_path.params,
+                    query=stage_path.query,
+                    fragment=stage_path.fragment,
+                )
+                stage_dir_path = parse.urlunparse(new_url)
+            else:
+                stage_dir_path = str((stage_path / relative_path).parent)
+
             retrying.retry(
                 retry_on_exception=_retry_on_sql_error,
                 stop_max_attempt_number=5,
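
Both call styles are now accepted; a sketch (assumes an existing Snowpark session and a local model/ directory; the stage and snow:// URL are hypothetical):

    import pathlib
    from urllib import parse

    from snowflake.ml._internal import file_utils

    # Plain stage path, as before:
    file_utils.upload_directory_to_stage(
        session, pathlib.Path("model"), pathlib.PurePosixPath("@MY_STAGE/base")
    )

    # Parsed URL, e.g. a versioned model location:
    file_utils.upload_directory_to_stage(
        session, pathlib.Path("model"), parse.urlparse("snow://model/MY_MODEL/versions/V1")
    )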

snowflake/ml/_internal/platform_capabilities.py

@@ -37,6 +37,9 @@ class PlatformCapabilities:
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
 
+    def is_live_commit_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
         try:

snowflake/ml/_internal/telemetry.py

@@ -353,6 +353,10 @@ def get_function_usage_statement_params(
         statement_params[TelemetryField.KEY_API_CALLS.value].append({TelemetryField.NAME.value: api_call})
     if custom_tags:
         statement_params[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
+    # Snowpark doesn't support None value in statement_params from version 1.29
+    for k in statement_params:
+        if statement_params[k] is None:
+            statement_params[k] = ""
     return statement_params
 
 
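In isolation, the new loop does the following (toy keys; real ones come from TelemetryField):

    statement_params = {"project": "MLOps", "custom_tag": None}

    # Snowpark >= 1.29 rejects None values, so they are coerced to empty strings:
    for k in statement_params:
        if statement_params[k] is None:
            statement_params[k] = ""

    assert statement_params == {"project": "MLOps", "custom_tag": ""}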

snowflake/ml/fileset/fileset.py

@@ -257,7 +257,6 @@ class FileSet:
                 function_name=telemetry.get_statement_params_full_func_name(
                     inspect.currentframe(), cls.__class__.__name__
                 ),
-                api_calls=[snowpark.DataFrameWriter.copy_into_location],
             ),
         )
     except snowpark_exceptions.SnowparkSQLException as e:

snowflake/ml/jobs/_utils/constants.py

@@ -12,12 +12,36 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "0.9.2"
+DEFAULT_IMAGE_TAG = "1.0.1"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
 MEMORY_VOLUME_SIZE = 0.3
 
+# Multi Node Headless prototype constants
+# TODO: Replace this placeholder with the actual container runtime image tag.
+MULTINODE_HEADLESS_IMAGE_TAG = "latest"
+
+# Ray port configuration
+RAY_PORTS = {
+    "HEAD_CLIENT_SERVER_PORT": "10001",
+    "HEAD_GCS_PORT": "12001",
+    "HEAD_DASHBOARD_GRPC_PORT": "12002",
+    "HEAD_DASHBOARD_PORT": "12003",
+    "OBJECT_MANAGER_PORT": "12011",
+    "NODE_MANAGER_PORT": "12012",
+    "RUNTIME_ENV_AGENT_PORT": "12013",
+    "DASHBOARD_AGENT_GRPC_PORT": "12014",
+    "DASHBOARD_AGENT_LISTEN_PORT": "12015",
+    "MIN_WORKER_PORT": "12031",
+    "MAX_WORKER_PORT": "13000",
+}
+
+# Node health check configuration
+# TODO(SNOW-1937020): Revisit the health check configuration
+ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ENABLE_HEALTH_CHECKS = "false"
+
 # Job status polling constants
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 1

snowflake/ml/jobs/_utils/payload_utils.py

@@ -73,8 +73,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     ##### Ray configuration #####
     shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
+    # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
+    HELPER_EXISTS=$(
+        python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
+    )
+
     # Configure IP address and logging directory
-    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+    if [ "$HELPER_EXISTS" = "true" ]; then
+        eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+    else
+        eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+    fi
     log_dir="/tmp/ray"
 
     # Check if eth0Ip is a valid IP address and fall back to default if necessary
@@ -82,6 +91,38 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         eth0Ip="127.0.0.1"
     fi
 
+    # Get the environment values of SNOWFLAKE_JOBS_COUNT and SNOWFLAKE_JOB_INDEX for batch jobs
+    # These variables don't exist for non-batch jobs, so set defaults
+    if [ -z "$SNOWFLAKE_JOBS_COUNT" ]; then
+        SNOWFLAKE_JOBS_COUNT=1
+    fi
+
+    if [ -z "$SNOWFLAKE_JOB_INDEX" ]; then
+        SNOWFLAKE_JOB_INDEX=0
+    fi
+
+    # Determine if it should be a worker or a head node for batch jobs
+    if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
+        head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
+        if [ $? -eq 0 ]; then
+            # Parse the output using read
+            read head_index head_ip <<< "$head_info"
+
+            # Use the parsed variables
+            echo "Head Instance Index: $head_index"
+            echo "Head Instance IP: $head_ip"
+
+        else
+            echo "Error: Failed to get head instance information."
+            echo "$head_info"  # Print the error message
+            exit 1
+        fi
+
+        if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+            NODE_TYPE="worker"
+        fi
+    fi
+
     # Common parameters for both head and worker nodes
     common_params=(
         "--node-ip-address=$eth0Ip"
@@ -97,29 +138,62 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         "--disable-usage-stats"
     )
 
-    # Additional head-specific parameters
-    head_params=(
-        "--head"
-        "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
-        "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Listening port for Ray Client Server
-        "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
-        "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc on
-        "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for local debugging
-        "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
-    )
+    if [ "$NODE_TYPE" = "worker" ]; then
+        # Use head_ip as head address if it exists
+        if [ ! -z "$head_ip" ]; then
+            RAY_HEAD_ADDRESS="$head_ip"
+        fi
+
+        # If RAY_HEAD_ADDRESS is still empty, exit with an error
+        if [ -z "$RAY_HEAD_ADDRESS" ]; then
+            echo "Error: Failed to determine head node address using default instance-index=0"
+            exit 1
+        fi
+
+        if [ -z "$SERVICE_NAME" ]; then
+            SERVICE_NAME="$SNOWFLAKE_SERVICE_NAME"
+        fi
+
+        if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
+            echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
+            exit 1
+        fi
+
+        # Additional worker-specific parameters
+        worker_params=(
+            "--address=${{RAY_HEAD_ADDRESS}}:12001"  # Connect to head node
+            "--resources={{\\"${{SERVICE_NAME}}\\":1, \\"node_tag:worker\\":1}}"  # Tag for node identification
+            "--object-store-memory=${{shm_size}}"
+        )
 
-    # Start Ray on the head node
-    ray start "${{common_params[@]}}" "${{head_params[@]}}" &
-    ##### End Ray configuration #####
+        # Start Ray on a worker node
+        ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
+    else
+
+        # Additional head-specific parameters
+        head_params=(
+            "--head"
+            "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
+            "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Port for Ray Client Server
+            "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
+            "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc
+            "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for debugging
+            "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
+        )
+
+        # Start Ray on the head node
+        ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+        ##### End Ray configuration #####
 
-    # TODO: Monitor MLRS and handle process crashes
-    python -m web.ml_runtime_grpc_server &
+        # TODO: Monitor MLRS and handle process crashes
+        python -m web.ml_runtime_grpc_server &
 
-    # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+        # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
 
-    # Run user's Python entrypoint
-    echo Running command: python "$@"
-    python "$@"
+        # Run user's Python entrypoint
+        echo Running command: python "$@"
+        python "$@"
+    fi
     """
 ).strip()
 

snowflake/ml/jobs/_utils/spec_utils.py

@@ -26,19 +26,22 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
     )
 
 
-def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
+def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = constants.DEFAULT_IMAGE_TAG
 
     # Try to pull latest image tag from server side if possible
-    query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
-    if query_result:
-        image_tag = query_result[0]["value"]
+    if not image_tag:
+        query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
+        if query_result:
+            image_tag = query_result[0]["value"]
+
+    if image_tag is None:
+        image_tag = constants.DEFAULT_IMAGE_TAG
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
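
The resulting precedence can be summarized with a small stand-in (illustrative only; resolve_tag is not part of the package):

    def resolve_tag(explicit_tag=None, account_param=None, default_tag="1.0.1"):
        # 1. explicit argument wins; 2. account-level RUNTIME_BASE_IMAGE_TAG; 3. default
        return explicit_tag or account_param or default_tag

    assert resolve_tag(explicit_tag="latest") == "latest"  # multi-node jobs pass a tag
    assert resolve_tag(account_param="0.9.9") == "0.9.9"   # server-side parameter
    assert resolve_tag() == "1.0.1"                        # constants.DEFAULT_IMAGE_TAG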
@@ -93,6 +96,8 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[List[str]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
 ) -> Dict[str, Any]:
     """
     Generate a service specification for a job.
@@ -102,12 +107,22 @@ def generate_service_spec(
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
+        num_instances: Number of instances for multi-node job
+        enable_metrics: Enable platform metrics for the job
 
     Returns:
         Job service specification
     """
+    is_multi_node = num_instances is not None and num_instances > 1
+
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    image_spec = _get_image_spec(session, compute_pool)
+    if is_multi_node:
+        # If the job is of multi-node, we will need a different image which contains
+        # module snowflake.runtime.utils.get_instance_ip
+        # TODO(SNOW-1961849): Remove the hard-coded image name
+        image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
+    else:
+        image_spec = _get_image_spec(session, compute_pool)
     resource_requests: Dict[str, Union[str, int]] = {
        "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
        "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -176,31 +191,73 @@ def generate_service_spec(
 
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
-    # Assemble into service specification dict
-    spec = {
-        "spec": {
-            "containers": [
-                {
-                    "name": constants.DEFAULT_CONTAINER_NAME,
-                    "image": image_spec.full_name,
-                    "command": ["/usr/local/bin/_entrypoint.sh"],
-                    "args": [
-                        stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v for v in payload.entrypoint
-                    ]
-                    + (args or []),
-                    "env": {
-                        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
-                    },
-                    "volumeMounts": volume_mounts,
-                    "resources": {
-                        "requests": resource_requests,
-                        "limits": resource_limits,
-                    },
+    env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
+    endpoints = []
+
+    if is_multi_node:
+        # Update environment variables for multi-node job
+        env_vars.update(constants.RAY_PORTS)
+        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+
+        # Define Ray endpoints for intra-service instance communication
+        ray_endpoints = [
+            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
+            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
+            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+            {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
+            {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
+            {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
+            {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
+            {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
+            {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
+        ]
+        endpoints.extend(ray_endpoints)
+
+    metrics = []
+    if enable_metrics:
+        # https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#label-spcs-available-platform-metrics
+        metrics = [
+            "system",
+            "status",
+            "network",
+            "storage",
+        ]
+
+    spec_dict = {
+        "containers": [
+            {
+                "name": constants.DEFAULT_CONTAINER_NAME,
+                "image": image_spec.full_name,
+                "command": ["/usr/local/bin/_entrypoint.sh"],
+                "args": [
+                    (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
+                ]
+                + (args or []),
+                "env": env_vars,
+                "volumeMounts": volume_mounts,
+                "resources": {
+                    "requests": resource_requests,
+                    "limits": resource_limits,
                 },
-            ],
-            "volumes": volumes,
-        }
+            },
+        ],
+        "volumes": volumes,
     }
+    if endpoints:
+        spec_dict["endpoints"] = endpoints
+    if metrics:
+        spec_dict.update(
+            {
+                "platformMonitor": {
+                    "metricConfig": {
+                        "groups": metrics,
+                    },
+                },
+            }
+        )
+
+    # Assemble into service specification dict
+    spec = {"spec": spec_dict}
 
     return spec
 
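A sketch of how the new parameters surface in the generated spec (pool name and args are hypothetical; payload stands in for a types.UploadedPayload from the upload step):

    spec = generate_service_spec(
        session,
        compute_pool="MY_GPU_POOL",
        payload=payload,
        args=["--epochs", "10"],
        num_instances=3,      # > 1 selects the multi-node image and adds the Ray endpoints
        enable_metrics=True,  # emits the platformMonitor section shown above
    )
    assert "endpoints" in spec["spec"]
    assert "platformMonitor" in spec["spec"]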
@@ -248,7 +305,10 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
 
 
 def _merge_lists_of_dicts(
-    base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
+    base: List[Dict[str, Any]],
+    patch: List[Dict[str, Any]],
+    merge_key: str = "name",
+    display_name: str = "",
 ) -> List[Dict[str, Any]]:
     """
     Attempts to merge lists of dicts by matching on a merge key (default "name").
@@ -288,7 +348,11 @@ def _merge_lists_of_dicts(
 
         # Apply patch
         if key in result:
-            d = merge_patch(result[key], d, display_name=f"{display_name}[{merge_key}={d[merge_key]}]")
+            d = merge_patch(
+                result[key],
+                d,
+                display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
+            )
         # TODO: Should we drop the item if the patch result is empty save for the merge key?
         # Can check `d.keys() <= {merge_key}`
         result[key] = d

snowflake/ml/jobs/decorators.py

@@ -19,11 +19,14 @@ _ReturnValue = TypeVar("_ReturnValue")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def remote(
     compute_pool: str,
+    *,
     stage_name: str,
     pip_requirements: Optional[List[str]] = None,
     external_access_integrations: Optional[List[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[Dict[str, str]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
     """
@@ -36,6 +39,8 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
+        num_instances: The number of nodes in the job. If none specified, create a single node job.
+        enable_metrics: Whether to enable metrics publishing for the job.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -61,6 +66,8 @@ def remote(
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         env_vars=env_vars,
+        num_instances=num_instances,
+        enable_metrics=enable_metrics,
         session=session,
     )
     assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
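
With the new * marker, everything after compute_pool must be passed by keyword. A decorator sketch (pool and stage names are hypothetical; assumes snowflake.ml.jobs re-exports remote):

    from snowflake.ml.jobs import remote

    @remote(
        "MY_POOL",
        stage_name="payload_stage",
        num_instances=2,      # one head plus one worker node
        enable_metrics=True,
    )
    def train(epochs: int = 10) -> None:
        ...

    job = train()  # returns an MLJob handle for status polling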