snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/constants.py
@@ -4,21 +4,51 @@ from snowflake.ml.jobs._utils.types import ComputeResources
  # SPCS specification constants
  DEFAULT_CONTAINER_NAME = "main"
  PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+ MEMORY_VOLUME_NAME = "dshm"
+ STAGE_VOLUME_NAME = "stage-volume"
+ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"

  # Default container image information
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
  DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
  DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
- DEFAULT_IMAGE_TAG = "0.8.0"
+ DEFAULT_IMAGE_TAG = "0.9.2"
  DEFAULT_ENTRYPOINT_PATH = "func.py"

  # Percent of container memory to allocate for /dev/shm volume
  MEMORY_VOLUME_SIZE = 0.3

+ # Multi Node Headless prototype constants
+ # TODO: Replace this placeholder with the actual container runtime image tag.
+ MULTINODE_HEADLESS_IMAGE_TAG = "latest"
+
+ # Ray port configuration
+ RAY_PORTS = {
+     "HEAD_CLIENT_SERVER_PORT": "10001",
+     "HEAD_GCS_PORT": "12001",
+     "HEAD_DASHBOARD_GRPC_PORT": "12002",
+     "HEAD_DASHBOARD_PORT": "12003",
+     "OBJECT_MANAGER_PORT": "12011",
+     "NODE_MANAGER_PORT": "12012",
+     "RUNTIME_ENV_AGENT_PORT": "12013",
+     "DASHBOARD_AGENT_GRPC_PORT": "12014",
+     "DASHBOARD_AGENT_LISTEN_PORT": "12015",
+     "MIN_WORKER_PORT": "12031",
+     "MAX_WORKER_PORT": "13000",
+ }
+
+ # Node health check configuration
+ # TODO(SNOW-1937020): Revisit the health check configuration
+ ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ ENABLE_HEALTH_CHECKS = "false"
+
  # Job status polling constants
  JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
  JOB_POLL_MAX_DELAY_SECONDS = 1

+ # Magic attributes
+ IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
+
  # Compute pool resource information
  # TODO: Query Snowflake for resource information instead of relying on this hardcoded
  # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
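For orientation, the new multi-node constants are consumed when the job's service spec is assembled (see the spec_utils.py hunks below). A minimal sketch of that wiring, assuming a dict of container environment variables; `build_env` is an illustrative helper, not part of the package:

```python
from snowflake.ml.jobs._utils import constants

def build_env(payload_dir: str, is_multi_node: bool) -> dict:
    # Hypothetical helper mirroring spec_utils.generate_service_spec below:
    # every Ray port plus the health-check flag becomes a container env var.
    env_vars = {constants.PAYLOAD_DIR_ENV_VAR: payload_dir}
    if is_multi_node:
        env_vars.update(constants.RAY_PORTS)  # e.g. "HEAD_GCS_PORT" -> "12001"
        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
    return env_vars
```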
snowflake/ml/jobs/_utils/payload_utils.py
@@ -1,5 +1,8 @@
+ import functools
  import inspect
  import io
+ import itertools
+ import pickle
  import sys
  import textwrap
  from pathlib import Path, PurePath
@@ -19,9 +22,11 @@ import cloudpickle as cp

  from snowflake import snowpark
  from snowflake.ml.jobs._utils import constants, types
+ from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation

  _SUPPORTED_ARG_TYPES = {str, int, float}
+ _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
  _STARTUP_SCRIPT_CODE = textwrap.dedent(
      f"""
@@ -68,16 +73,56 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  ##### Ray configuration #####
  shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)

+ # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
+ HELPER_EXISTS=$(
+     python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
+ )
+
  # Configure IP address and logging directory
- eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+ if [ "$HELPER_EXISTS" = "true" ]; then
+     eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+ else
+     eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+ fi
  log_dir="/tmp/ray"

- # Check if eth0Ip is empty and set default if necessary
- if [ -z "$eth0Ip" ]; then
-     # This should never happen, but just in case ethOIp is not set, we should default to localhost
+ # Check if eth0Ip is a valid IP address and fall back to default if necessary
+ if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
      eth0Ip="127.0.0.1"
  fi

+ # Get the environment values of SNOWFLAKE_JOBS_COUNT and SNOWFLAKE_JOB_INDEX for batch jobs
+ # These variables don't exist for non-batch jobs, so set defaults
+ if [ -z "$SNOWFLAKE_JOBS_COUNT" ]; then
+     SNOWFLAKE_JOBS_COUNT=1
+ fi
+
+ if [ -z "$SNOWFLAKE_JOB_INDEX" ]; then
+     SNOWFLAKE_JOB_INDEX=0
+ fi
+
+ # Determine if it should be a worker or a head node for batch jobs
+ if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
+     head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
+     if [ $? -eq 0 ]; then
+         # Parse the output using read
+         read head_index head_ip <<< "$head_info"
+
+         # Use the parsed variables
+         echo "Head Instance Index: $head_index"
+         echo "Head Instance IP: $head_ip"
+
+     else
+         echo "Error: Failed to get head instance information."
+         echo "$head_info"  # Print the error message
+         exit 1
+     fi
+
+     if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+         NODE_TYPE="worker"
+     fi
+ fi
+
  # Common parameters for both head and worker nodes
  common_params=(
      "--node-ip-address=$eth0Ip"
@@ -93,33 +138,94 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
      "--disable-usage-stats"
  )

- # Additional head-specific parameters
- head_params=(
-     "--head"
-     "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
-     "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Listening port for Ray Client Server
-     "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
-     "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc on
-     "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for local debugging
-     "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
- )
+ if [ "$NODE_TYPE" = "worker" ]; then
+     # Use head_ip as head address if it exists
+     if [ ! -z "$head_ip" ]; then
+         RAY_HEAD_ADDRESS="$head_ip"
+     fi
+
+     # If RAY_HEAD_ADDRESS is still empty, exit with an error
+     if [ -z "$RAY_HEAD_ADDRESS" ]; then
+         echo "Error: Failed to determine head node address using default instance-index=0"
+         exit 1
+     fi
+
+     if [ -z "$SERVICE_NAME" ]; then
+         SERVICE_NAME="$SNOWFLAKE_SERVICE_NAME"
+     fi
+
+     if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
+         echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
+         exit 1
+     fi
+
+     # Additional worker-specific parameters
+     worker_params=(
+         "--address=${{RAY_HEAD_ADDRESS}}:12001"  # Connect to head node
+         "--resources={{\\"${{SERVICE_NAME}}\\":1, \\"node_tag:worker\\":1}}"  # Tag for node identification
+         "--object-store-memory=${{shm_size}}"
+     )

- # Start Ray on the head node
- ray start "${{common_params[@]}}" "${{head_params[@]}}" &
- ##### End Ray configuration #####
+     # Start Ray on a worker node
+     ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
+ else
+
+     # Additional head-specific parameters
+     head_params=(
+         "--head"
+         "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
+         "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Port for Ray Client Server
+         "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
+         "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc
+         "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for debugging
+         "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
+     )
+
+     # Start Ray on the head node
+     ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+     ##### End Ray configuration #####

- # TODO: Monitor MLRS and handle process crashes
- python -m web.ml_runtime_grpc_server &
+     # TODO: Monitor MLRS and handle process crashes
+     python -m web.ml_runtime_grpc_server &

- # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+     # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started

- # Run user's Python entrypoint
- echo Running command: python "$@"
- python "$@"
+     # Run user's Python entrypoint
+     echo Running command: python "$@"
+     python "$@"
+ fi
  """
  ).strip()


+ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
+     parent = parent.absolute()
+     if entrypoint is None:
+         if parent.is_file():
+             # Infer entrypoint from source
+             entrypoint = parent
+         else:
+             raise ValueError("entrypoint must be provided when source is a directory")
+     elif entrypoint.is_absolute():
+         # Absolute path - validate it's a subpath of source dir
+         if not entrypoint.is_relative_to(parent):
+             raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+     else:
+         # Relative path
+         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
+             # Relative to working dir iff path is relative to source dir and exists
+             entrypoint = abs_entrypoint
+         else:
+             # Relative to source dir
+             entrypoint = parent.joinpath(entrypoint)
+     if not entrypoint.is_file():
+         raise FileNotFoundError(
+             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
+             f" the source directory (source={parent}, entrypoint={entrypoint})"
+         )
+     return entrypoint
+
+
  class JobPayload:
      def __init__(
          self,
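The new `_resolve_entrypoint` helper centralizes the path resolution that `JobPayload.validate()` previously inlined (see the next hunk). A minimal sketch of the three resolution cases, assuming the private helper is importable and using a throwaway directory:

```python
import tempfile
from pathlib import Path

# Assumption: the private helper is importable; illustrative use only.
from snowflake.ml.jobs._utils.payload_utils import _resolve_entrypoint

with tempfile.TemporaryDirectory() as d:
    src = Path(d) / "src"
    src.mkdir()
    (src / "train.py").write_text("print('hello')\n")

    # 1. Source is a file: the entrypoint is inferred from it.
    assert _resolve_entrypoint(src / "train.py", None).name == "train.py"
    # 2. Absolute entrypoint: accepted only as a subpath of the source dir.
    assert _resolve_entrypoint(src, (src / "train.py").absolute()).name == "train.py"
    # 3. Relative entrypoint: resolved against the source dir when it is not
    #    a valid path relative to the working directory.
    assert _resolve_entrypoint(src, Path("train.py")).name == "train.py"
```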
@@ -138,23 +244,23 @@ class JobPayload:
              # since we will generate the file from the serialized callable
              pass
          elif isinstance(self.source, Path):
-             # Validate self.source and self.entrypoint for files
-             if not self.source.exists():
-                 raise FileNotFoundError(f"{self.source} does not exist")
-             if self.entrypoint is None:
-                 if self.source.is_file():
-                     self.entrypoint = self.source
-                 else:
-                     raise ValueError("entrypoint must be provided when source is a directory")
-             if not self.entrypoint.is_file():
-                 # Check if self.entrypoint is a valid relative path
-                 self.entrypoint = self.source.joinpath(self.entrypoint)
-                 if not self.entrypoint.is_file():
-                     raise FileNotFoundError(f"File {self.entrypoint} does not exist")
-             if not self.entrypoint.is_relative_to(self.source):
-                 raise ValueError(f"{self.entrypoint} must be a subpath of {self.source}")
-             if self.entrypoint.suffix != ".py":
-                 raise NotImplementedError("Only Python entrypoints are supported currently")
+             # Validate source
+             source = self.source
+             if not source.exists():
+                 raise FileNotFoundError(f"{source} does not exist")
+             source = source.absolute()
+
+             # Validate entrypoint
+             entrypoint = _resolve_entrypoint(source, self.entrypoint)
+             if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+                 raise ValueError(
+                     "Unsupported entrypoint type:"
+                     f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+                 )
+
+             # Update fields with normalized values
+             self.source = source
+             self.entrypoint = entrypoint
          else:
              raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")

@@ -168,12 +274,16 @@ class JobPayload:
          entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)

          # Create stage if necessary
-         stage_name = stage_path.parts[0]
-         session.sql(
-             f"create stage if not exists {stage_name.lstrip('@')}"
-             " encryption = ( type = 'SNOWFLAKE_SSE' )"
-             " comment = 'Created by snowflake.ml.jobs Python API'"
-         ).collect()
+         stage_name = stage_path.parts[0].lstrip("@")
+         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
+         try:
+             session.sql(f"describe stage {stage_name}").collect()
+         except sp_exceptions.SnowparkSQLException:
+             session.sql(
+                 f"create stage if not exists {stage_name}"
+                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
+                 " comment = 'Created by snowflake.ml.jobs Python API'"
+             ).collect()

          # Upload payload to stage
          if not isinstance(source, Path):
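The describe-then-create pattern above matters for least-privilege setups: a caller without CREATE STAGE can still run jobs against a pre-created stage. The same probe in isolation, as a hedged sketch (`ensure_stage` is illustrative, not a package API):

```python
from snowflake.snowpark import Session, exceptions as sp_exceptions

def ensure_stage(session: Session, stage_name: str) -> None:
    # DESCRIBE succeeds iff the stage exists; only fall through to CREATE
    # (which needs the CREATE STAGE privilege) when it does not.
    try:
        session.sql(f"describe stage {stage_name}").collect()
    except sp_exceptions.SnowparkSQLException:
        session.sql(
            f"create stage if not exists {stage_name}"
            " encryption = ( type = 'SNOWFLAKE_SSE' )"
        ).collect()
```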
@@ -237,7 +347,7 @@ class JobPayload:
          )


- def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
      # Unwrap Optional type annotations
      param_type = param.annotation
      if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -249,7 +359,7 @@ def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
      return cast(Type[object], param_type)


- def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+ def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
      # Validate param_type is a supported type
      if param_type not in _SUPPORTED_ARG_TYPES:
          raise ValueError(
@@ -258,41 +368,60 @@ def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
      )


- def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
-     signature = inspect.signature(func)
-     if any(
-         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-         for p in signature.parameters.values()
-     ):
-         raise NotImplementedError("Function must not have unpacking arguments (* or **)")
-
-     # Mirrored from Snowpark generate_python_code() function
-     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+ def _generate_source_code_comment(func: Callable[..., Any]) -> str:
+     """Generate a comment string containing the source code of a function for readability."""
      try:
-         source_code_comment = (
-             code_generation.generate_source_code(func) if source_code_display else ""  # type: ignore[arg-type]
-         )
+         if isinstance(func, functools.partial):
+             # Unwrap functools.partial and generate source code comment from the original function
+             comment = code_generation.generate_source_code(func.func)  # type: ignore[arg-type]
+             args = itertools.chain((repr(a) for a in func.args), (f"{k}={v!r}" for k, v in func.keywords.items()))
+
+             # Update invocation comment to show arguments passed via functools.partial
+             comment = comment.replace(
+                 f"= {func.func.__name__}",
+                 "= functools.partial({}({}))".format(
+                     func.func.__name__,
+                     ", ".join(args),
+                 ),
+             )
+             return comment
+         else:
+             return code_generation.generate_source_code(func)  # type: ignore[arg-type]
      except Exception as exc:
          error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
-         source_code_comment = code_generation.comment_source_code(error_msg)
-
-     func_name = "func"
-     func_code = f"""
- {source_code_comment}
+         return code_generation.comment_source_code(error_msg)

- import pickle
- {func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
- """

+ def _serialize_callable(func: Callable[..., Any]) -> bytes:
+     try:
+         func_bytes: bytes = cp.dumps(func)
+         return func_bytes
+     except pickle.PicklingError as e:
+         if isinstance(func, functools.partial):
+             # Try to find which part of the partial isn't serializable for better debuggability
+             objects = [
+                 ("function", func.func),
+                 *((f"positional arg {i}", a) for i, a in enumerate(func.args)),
+                 *((f"keyword arg '{k}'", v) for k, v in func.keywords.items()),
+             ]
+             for name, obj in objects:
+                 try:
+                     cp.dumps(obj)
+                 except pickle.PicklingError:
+                     raise ValueError(f"Unable to serialize {name}: {obj}") from e
+         raise ValueError(f"Unable to serialize function: {func}") from e
+
+
+ def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
      # Generate argparse logic for argument handling (type coercion, default values, etc)
      argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
      argparse_postproc = []
      for name, param in signature.parameters.items():
          opts = {}

-         param_type = get_parameter_type(param)
+         param_type = _get_parameter_type(param)
          if param_type is not None:
-             validate_parameter_type(param_type, name)
+             _validate_parameter_type(param_type, name)
              opts["type"] = param_type.__name__

          if param.default != inspect.Parameter.empty:
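`_serialize_callable` exists to turn an opaque cloudpickle failure into a message naming the offending piece of a `functools.partial`. A minimal sketch of that failure mode; `Unpicklable` is contrived purely to force a `PicklingError`:

```python
import functools
import pickle

import cloudpickle as cp  # the same serializer the module aliases as `cp`

class Unpicklable:
    def __reduce__(self):
        # Contrived: guarantee a PicklingError for the demo.
        raise pickle.PicklingError("intentionally unpicklable")

def train(data_path: str, client=None) -> None:
    ...

bad = functools.partial(train, "@my_stage/data", client=Unpicklable())
# cp.dumps(bad) raises PicklingError; _serialize_callable probes each piece
# of the partial and re-raises as:
#   ValueError: Unable to serialize keyword arg 'client': <...Unpicklable...>
```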
@@ -324,6 +453,37 @@ import pickle
      )
      argparse_code.append("args = parser.parse_args()")
      param_code = "\n".join(argparse_code + argparse_postproc)
+     param_code += f"\n{output_name} = vars(args)"
+
+     return param_code
+
+
+ def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+     """Generate an entrypoint script from a Python function."""
+     signature = inspect.signature(func)
+     if any(
+         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+         for p in signature.parameters.values()
+     ):
+         raise NotImplementedError("Function must not have unpacking arguments (* or **)")
+
+     # Mirrored from Snowpark generate_python_code() function
+     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+
+     func_name = "func"
+     func_code = f"""
+ {source_code_comment}
+
+ import pickle
+ {func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+ """
+
+     arg_dict_name = "kwargs"
+     if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+         param_code = f"{arg_dict_name} = {{}}"
+     else:
+         param_code = _generate_param_handler_code(signature, arg_dict_name)

      return f"""
  ### Version guard to check compatibility across Python versions ###
@@ -348,5 +508,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
  if __name__ == '__main__':
  {textwrap.indent(param_code, '    ')}

-     {func_name}(**vars(args))
+     {func_name}(**{arg_dict_name})
  """
snowflake/ml/jobs/_utils/spec_utils.py
@@ -26,19 +26,22 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
      )


- def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
+ def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
      # Retrieve compute pool node resources
      resources = _get_node_resources(session, compute_pool=compute_pool)

      # Use MLRuntime image
      image_repo = constants.DEFAULT_IMAGE_REPO
      image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-     image_tag = constants.DEFAULT_IMAGE_TAG

      # Try to pull latest image tag from server side if possible
-     query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
-     if query_result:
-         image_tag = query_result[0]["value"]
+     if not image_tag:
+         query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
+         if query_result:
+             image_tag = query_result[0]["value"]
+
+     if image_tag is None:
+         image_tag = constants.DEFAULT_IMAGE_TAG

      # TODO: Should each instance consume the entire pod?
      return types.ImageSpec(
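`_get_image_spec` now resolves the image tag with an explicit precedence: caller-supplied `image_tag`, then the account-level `RUNTIME_BASE_IMAGE_TAG` parameter, then `DEFAULT_IMAGE_TAG`. The same order in isolation, as a sketch (`resolve_image_tag` is illustrative; `account_tag` stands in for the SHOW PARAMETERS lookup):

```python
from typing import Optional

DEFAULT_IMAGE_TAG = "0.9.2"

def resolve_image_tag(explicit: Optional[str], account_tag: Optional[str]) -> str:
    if explicit:              # caller override wins
        return explicit
    if account_tag:           # then the account-level parameter
        return account_tag
    return DEFAULT_IMAGE_TAG  # finally the library default

assert resolve_image_tag(None, None) == "0.9.2"
assert resolve_image_tag(None, "0.9.5") == "0.9.5"
assert resolve_image_tag("latest", "0.9.5") == "latest"
```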
@@ -93,6 +96,7 @@ def generate_service_spec(
      compute_pool: str,
      payload: types.UploadedPayload,
      args: Optional[List[str]] = None,
+     num_instances: Optional[int] = None,
  ) -> Dict[str, Any]:
      """
      Generate a service specification for a job.
@@ -102,12 +106,21 @@
          compute_pool: Compute pool for job execution
          payload: Uploaded job payload
          args: Arguments to pass to entrypoint script
+         num_instances: Number of instances for multi-node job

      Returns:
          Job service specification
      """
+     is_multi_node = num_instances is not None and num_instances > 1
+
      # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-     image_spec = _get_image_spec(session, compute_pool)
+     if is_multi_node:
+         # If the job is of multi-node, we will need a different image which contains
+         # module snowflake.runtime.utils.get_instance_ip
+         # TODO(SNOW-1961849): Remove the hard-coded image name
+         image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
+     else:
+         image_spec = _get_image_spec(session, compute_pool)
      resource_requests: Dict[str, Union[str, int]] = {
          "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
          "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -141,68 +154,88 @@ def generate_service_spec(
      )

      # Mount 30% of memory limit as a memory-backed volume
-     memory_volume_name = "dshm"
      memory_volume_size = min(
          ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
          image_spec.resource_requests.memory,
      )
      volume_mounts.append(
          {
-             "name": memory_volume_name,
+             "name": constants.MEMORY_VOLUME_NAME,
              "mountPath": "/dev/shm",
          }
      )
      volumes.append(
          {
-             "name": memory_volume_name,
+             "name": constants.MEMORY_VOLUME_NAME,
              "source": "memory",
              "size": f"{memory_volume_size}Gi",
          }
      )

      # Mount payload as volume
-     stage_mount = PurePath("/opt/app")
-     stage_volume_name = "stage-volume"
+     stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
      volume_mounts.append(
          {
-             "name": stage_volume_name,
+             "name": constants.STAGE_VOLUME_NAME,
              "mountPath": stage_mount.as_posix(),
          }
      )
      volumes.append(
          {
-             "name": stage_volume_name,
+             "name": constants.STAGE_VOLUME_NAME,
              "source": payload.stage_path.as_posix(),
          }
      )

      # TODO: Add hooks for endpoints for integration with TensorBoard etc

-     # Assemble into service specification dict
-     spec = {
-         "spec": {
-             "containers": [
-                 {
-                     "name": constants.DEFAULT_CONTAINER_NAME,
-                     "image": image_spec.full_name,
-                     "command": ["/usr/local/bin/_entrypoint.sh"],
-                     "args": [
-                         stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v for v in payload.entrypoint
-                     ]
-                     + (args or []),
-                     "env": {
-                         constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
-                     },
-                     "volumeMounts": volume_mounts,
-                     "resources": {
-                         "requests": resource_requests,
-                         "limits": resource_limits,
-                     },
+     env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
+     endpoints = []
+
+     if is_multi_node:
+         # Update environment variables for multi-node job
+         env_vars.update(constants.RAY_PORTS)
+         env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+
+         # Define Ray endpoints for intra-service instance communication
+         ray_endpoints = [
+             {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
+             {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
+             {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+             {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
+             {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
+             {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
+             {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
+             {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
+             {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
+         ]
+         endpoints.extend(ray_endpoints)
+
+     spec_dict = {
+         "containers": [
+             {
+                 "name": constants.DEFAULT_CONTAINER_NAME,
+                 "image": image_spec.full_name,
+                 "command": ["/usr/local/bin/_entrypoint.sh"],
+                 "args": [
+                     (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
+                 ]
+                 + (args or []),
+                 "env": env_vars,
+                 "volumeMounts": volume_mounts,
+                 "resources": {
+                     "requests": resource_requests,
+                     "limits": resource_limits,
                  },
-             ],
-             "volumes": volumes,
-         }
+             },
+         ],
+         "volumes": volumes,
      }
+     if endpoints:
+         spec_dict["endpoints"] = endpoints
+
+     # Assemble into service specification dict
+     spec = {"spec": spec_dict}

      return spec

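For a multi-node job the assembled spec therefore grows two things: the Ray port map in the container's `env`, and an `endpoints` list for instance-to-instance traffic. An abbreviated sketch of the resulting structure (values illustrative):

```python
# Abbreviated shape of generate_service_spec(...) output for num_instances > 1:
spec = {
    "spec": {
        "containers": [
            {
                "name": "main",
                "image": "<repo>/<image>:latest",  # MULTINODE_HEADLESS_IMAGE_TAG
                "env": {
                    "MLRS_PAYLOAD_DIR": "/mnt/app",
                    "HEAD_GCS_PORT": "12001",        # from constants.RAY_PORTS
                    "ENABLE_HEALTH_CHECKS": "false",
                    # ...
                },
                # command, args, volumeMounts, resources ...
            },
        ],
        "volumes": [],  # dshm memory volume + stage-volume (elided)
        "endpoints": [
            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
            # ... remaining Ray endpoints and the two port ranges
        ],
    }
}
```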
@@ -250,7 +283,10 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:


  def _merge_lists_of_dicts(
-     base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
+     base: List[Dict[str, Any]],
+     patch: List[Dict[str, Any]],
+     merge_key: str = "name",
+     display_name: str = "",
  ) -> List[Dict[str, Any]]:
      """
      Attempts to merge lists of dicts by matching on a merge key (default "name").
@@ -290,7 +326,11 @@ def _merge_lists_of_dicts(

      # Apply patch
      if key in result:
-         d = merge_patch(result[key], d, display_name=f"{display_name}[{merge_key}={d[merge_key]}]")
+         d = merge_patch(
+             result[key],
+             d,
+             display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
+         )
      # TODO: Should we drop the item if the patch result is empty save for the merge key?
      # Can check `d.keys() <= {merge_key}`
      result[key] = d
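`_merge_lists_of_dicts` is what lets a user-supplied spec patch target, say, the `main` container by name rather than by list position. A simplified stand-in for the semantics (shallow merge instead of the recursive `merge_patch`), purely illustrative:

```python
from typing import Any, Dict, List

def merge_lists_by_name(
    base: List[Dict[str, Any]], patch: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    # Index base entries by "name", then lay each patch entry over its match.
    result = {d["name"]: d for d in base}
    for d in patch:
        result[d["name"]] = {**result.get(d["name"], {}), **d}
    return list(result.values())

base = [{"name": "main", "image": "runtime:0.9.2", "env": {"A": "1"}}]
patch = [{"name": "main", "image": "runtime:custom"}]
assert merge_lists_by_name(base, patch) == [
    {"name": "main", "image": "runtime:custom", "env": {"A": "1"}}
]
```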