snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. snowflake/cortex/_complete.py +44 -10
  2. snowflake/ml/_internal/platform_capabilities.py +39 -3
  3. snowflake/ml/data/data_connector.py +25 -0
  4. snowflake/ml/dataset/dataset_reader.py +5 -1
  5. snowflake/ml/jobs/_utils/constants.py +2 -4
  6. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +81 -47
  8. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
  11. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  12. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  13. snowflake/ml/jobs/_utils/spec_utils.py +5 -8
  14. snowflake/ml/jobs/_utils/types.py +6 -0
  15. snowflake/ml/jobs/decorators.py +3 -3
  16. snowflake/ml/jobs/job.py +145 -23
  17. snowflake/ml/jobs/manager.py +62 -10
  18. snowflake/ml/model/_client/ops/service_ops.py +42 -35
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
  20. snowflake/ml/model/_client/sql/service.py +9 -5
  21. snowflake/ml/model/_model_composer/model_composer.py +29 -11
  22. snowflake/ml/model/_packager/model_env/model_env.py +8 -2
  23. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
  24. snowflake/ml/model/_packager/model_packager.py +2 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  26. snowflake/ml/model/type_hints.py +2 -0
  27. snowflake/ml/registry/_manager/model_manager.py +20 -1
  28. snowflake/ml/registry/registry.py +5 -1
  29. snowflake/ml/version.py +1 -1
  30. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +35 -4
  31. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +34 -28
  32. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +0 -0
  33. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  34. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/interop_utils.py (new file)
@@ -0,0 +1,442 @@
+ import builtins
+ import functools
+ import importlib
+ import json
+ import os
+ import pickle
+ import re
+ import sys
+ import traceback
+ from collections import namedtuple
+ from dataclasses import dataclass
+ from types import TracebackType
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
+
+ from snowflake import snowpark
+ from snowflake.snowpark import exceptions as sp_exceptions
+
+ _TRACEBACK_ENTRY_PATTERN = re.compile(
+     r'File "(?P<filename>[^"]+)", line (?P<lineno>\d+), in (?P<name>[^\n]+)(?:\n(?!^\s*File)^\s*(?P<line>[^\n]+))?\n',
+     flags=re.MULTILINE,
+ )
+ _REMOTE_ERROR_ATTR_NAME = "_remote_error"
+
+ RemoteError = namedtuple("RemoteError", ["exc_type", "exc_msg", "exc_tb"])
+
+
+ @dataclass(frozen=True)
+ class ExecutionResult:
+     result: Any = None
+     exception: Optional[BaseException] = None
+
+     @property
+     def success(self) -> bool:
+         return self.exception is None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the serializable dictionary."""
+         if isinstance(self.exception, BaseException):
+             exc_type = type(self.exception)
+             return {
+                 "success": False,
+                 "exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
+                 "exc_value": self.exception,
+                 "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
+             }
+         return {
+             "success": True,
+             "result_type": type(self.result).__qualname__,
+             "result": self.result,
+         }
+
+     @classmethod
+     def from_dict(cls, result_dict: Dict[str, Any]) -> "ExecutionResult":
+         if not isinstance(result_dict.get("success"), bool):
+             raise ValueError("Invalid result dictionary")
+
+         if result_dict["success"]:
+             # Load successful result
+             return cls(result=result_dict.get("result"))
+
+         # Load exception
+         exc_type = result_dict.get("exc_type", "RuntimeError")
+         exc_value = result_dict.get("exc_value", "Unknown error")
+         exc_tb = result_dict.get("exc_tb", "")
+         return cls(exception=load_exception(exc_type, exc_value, exc_tb))
+
+
+ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult:
+     """
+     Fetch the serialized result from the specified path.
+
+     Args:
+         session: Snowpark Session to use for file operations.
+         result_path: The path to the serialized result file.
+
+     Returns:
+         The ExecutionResult deserialized from the result file.
+     """
+     try:
+         # TODO: Check if file exists
+         with session.file.get_stream(result_path) as result_stream:
+             return ExecutionResult.from_dict(pickle.load(result_stream))
+     except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
+         # Fall back to JSON result if loading pickled result fails for any reason
+         result_json_path = os.path.splitext(result_path)[0] + ".json"
+         with session.file.get_stream(result_json_path) as result_stream:
+             return ExecutionResult.from_dict(json.load(result_stream))
+
+
+ def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
+     """
+     Create an exception with a string-formatted traceback.
+
+     When this exception is raised and not caught, it will display the original traceback.
+     When caught, it behaves like a regular exception without showing the traceback.
+
+     Args:
+         exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
+         exc_value: The deserialized exception value or exception string (i.e. message)
+         exc_tb: String representation of the traceback
+
+     Returns:
+         An exception object with the original traceback information
+
+     # noqa: DAR401
+     """
+     if isinstance(exc_value, Exception):
+         exception = exc_value
+     else:
+         # Try to load the original exception type if possible
+         try:
+             # First check built-in exceptions
+             exc_type = getattr(builtins, exc_type_name, None)
+             if exc_type is None and "." in exc_type_name:
+                 # Try to import from module path if it's a qualified name
+                 module_path, class_name = exc_type_name.rsplit(".", 1)
+                 module = importlib.import_module(module_path)
+                 exc_type = getattr(module, class_name)
+             if exc_type is None or not issubclass(exc_type, Exception):
+                 raise TypeError(f"{exc_type_name} is not a known exception type")
+             # Create the exception instance
+             exception = exc_type(exc_value)
+         except (ImportError, AttributeError, TypeError):
+             # Fall back to a generic exception
+             exception = RuntimeError(
+                 f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}"
+             )
+
+     # Attach the traceback information to the exception
+     return _attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
+
+
+ def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceback_str: str) -> Exception:
+     """
+     Attach a string-formatted traceback to an exception.
+
+     When the exception is raised and not caught, it will display the original traceback.
+     When caught, it behaves like a regular exception without showing the traceback.
+
+     Args:
+         ex: The exception object to modify
+         exc_type: The original exception type name
+         exc_msg: The original exception message
+         traceback_str: String representation of the traceback
+
+     Returns:
+         An exception object with the original traceback information
+     """
+     # Store the traceback information
+     exc_type = exc_type.rsplit(".", 1)[-1]  # Remove module path
+     setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteError(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str))
+     return ex
+
+
+ def _retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteError]:
+     """
+     Retrieve the string-formatted traceback from an exception if it exists.
+
+     Args:
+         ex: The exception to retrieve the traceback from
+
+     Returns:
+         The remote error tuple if it exists, None otherwise
+     """
+     if not ex:
+         return None
+     return getattr(ex, _REMOTE_ERROR_ATTR_NAME, None)
+
+
+ # ###############################################################################
+ # ------------------------------- !!! NOTE !!! -------------------------------- #
+ # ###############################################################################
+ # Job execution results (including uncaught exceptions) are serialized to file(s)
+ # in mljob_launcher.py. When the job is executed remotely, the serialized results
+ # are fetched and deserialized in the local environment. If the result contains
+ # an exception the original traceback is reconstructed and displayed to the user.
+ #
+ # It's currently impossible to recreate the original traceback object, so the
+ # following overrides are necessary to attach and display the deserialized
+ # traceback during exception handling.
+ #
+ # The following code implements the necessary overrides including sys.excepthook
+ # modifications and IPython traceback formatting. The hooks are applied on init
+ # and will be active for the duration of the process. The hooks are designed to
+ # self-uninstall in the event of an error in case of future compatibility issues.
+ # ###############################################################################
+
+
+ def _revert_func_wrapper(
+     patched_func: Callable[..., Any],
+     original_func: Callable[..., Any],
+     uninstall_func: Callable[[], None],
+ ) -> Callable[..., Any]:
+     """
+     Create a wrapper function that uninstalls the patched function and falls back to the original if an error occurs.
+
+     This wrapper provides a fallback mechanism where if the patched function fails, it will:
+     1. Uninstall the patched function using the provided uninstall_func, reverting back to using the original function
+     2. Re-execute the current call using the original (unpatched) function with the same arguments
+
+     Args:
+         patched_func: The patched function to call.
+         original_func: The original function to call if patched_func fails.
+         uninstall_func: The function to call to uninstall the patched function.
+
+     Returns:
+         A wrapped function that calls patched_func and uninstalls on failure.
+     """
+
+     @functools.wraps(patched_func)
+     def wrapped(*args: Any, **kwargs: Any) -> Any:
+         try:
+             return patched_func(*args, **kwargs)
+         except Exception:
+             # Uninstall and revert to original on failure
+             uninstall_func()
+             return original_func(*args, **kwargs)
+
+     return wrapped
+
+
+ def _install_sys_excepthook() -> None:
+     """
+     Install a custom sys.excepthook to handle remote exception tracebacks.
+
+     sys.excepthook is the global hook that Python calls when an unhandled exception occurs.
+     By default it prints the exception type, message and traceback to stderr.
+
+     We override sys.excepthook to intercept exceptions that contain our special RemoteError
+     attribute. These exceptions come from deserialized remote execution results and contain
+     the original traceback information from where they occurred.
+
+     When such an exception is detected, we format and display the original remote traceback
+     instead of the local one, which provides better debugging context by showing where the
+     error actually happened during remote execution.
+
+     The custom hook maintains proper exception chaining for both __cause__ (from raise from)
+     and __context__ (from implicit exception chaining).
+     """
+     # Attach the custom excepthook for standard Python scripts if not already attached
+     if not hasattr(sys, "_original_excepthook"):
+         original_excepthook = sys.excepthook
+
+         def custom_excepthook(
+             exc_type: Type[BaseException],
+             exc_value: BaseException,
+             exc_tb: Optional[TracebackType],
+             *,
+             seen_exc_ids: Optional[Set[int]] = None,
+         ) -> None:
+             if seen_exc_ids is None:
+                 seen_exc_ids = set()
+             seen_exc_ids.add(id(exc_value))
+
+             cause = getattr(exc_value, "__cause__", None)
+             context = getattr(exc_value, "__context__", None)
+             if cause:
+                 # Handle cause-chained exceptions
+                 custom_excepthook(type(cause), cause, cause.__traceback__, seen_exc_ids=seen_exc_ids)
+                 print(  # noqa: T201
+                     "\nThe above exception was the direct cause of the following exception:\n", file=sys.stderr
+                 )
+             elif context and not getattr(exc_value, "__suppress_context__", False):
+                 # Handle context-chained exceptions
+                 # Only process context if it's different from cause to avoid double printing
+                 custom_excepthook(type(context), context, context.__traceback__, seen_exc_ids=seen_exc_ids)
+                 print(  # noqa: T201
+                     "\nDuring handling of the above exception, another exception occurred:\n", file=sys.stderr
+                 )
+
+             if (remote_err := _retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteError):
+                 # Display stored traceback for deserialized exceptions
+                 print("Traceback (from remote execution):", file=sys.stderr)  # noqa: T201
+                 print(remote_err.exc_tb, end="", file=sys.stderr)  # noqa: T201
+                 print(f"{remote_err.exc_type}: {remote_err.exc_msg}", file=sys.stderr)  # noqa: T201
+             else:
+                 # Fall back to the original excepthook
+                 traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stderr, chain=False)
+
+         sys._original_excepthook = original_excepthook  # type: ignore[attr-defined]
+         sys.excepthook = _revert_func_wrapper(custom_excepthook, original_excepthook, _uninstall_sys_excepthook)
+
+
+ def _uninstall_sys_excepthook() -> None:
+     """
+     Restore the original excepthook for the current process.
+
+     This is useful when we want to revert to the default behavior after installing a custom excepthook.
+     """
+     if hasattr(sys, "_original_excepthook"):
+         sys.excepthook = sys._original_excepthook
+         del sys._original_excepthook
+
+
+ def _install_ipython_hook() -> bool:
+     """Install IPython-specific exception handling hook to improve remote error reporting.
+
+     This function enhances IPython's error formatting capabilities by intercepting and customizing
+     how remote execution errors are displayed. It modifies two key IPython traceback formatters:
+
+     1. VerboseTB.format_exception_as_a_whole: Customizes the full traceback formatting for remote
+        errors by:
+        - Adding a "(from remote execution)" header instead of "(most recent call last)"
+        - Properly formatting the remote traceback entries
+        - Maintaining original behavior for non-remote errors
+
+     2. ListTB.structured_traceback: Modifies the structured traceback output by:
+        - Parsing and formatting remote tracebacks appropriately
+        - Adding remote execution context to the output
+        - Preserving original functionality for local errors
+
+     The modifications are needed because IPython's default error handling doesn't properly display
+     remote execution errors that occur in Snowpark/Snowflake operations. The custom formatters
+     ensure that error messages from remote executions are properly captured, formatted and displayed
+     with the correct context and traceback information.
+
+     Returns:
+         bool: True if IPython hooks were successfully installed, False if IPython is not available
+         or not in an IPython environment.
+
+     Note:
+         This function maintains the ability to revert changes through _uninstall_ipython_hook by
+         storing original implementations before applying modifications.
+     """
+     try:
+         from IPython.core.getipython import get_ipython
+         from IPython.core.ultratb import ListTB, VerboseTB
+
+         if get_ipython() is None:
+             return False
+     except ImportError:
+         return False
+
+     def parse_traceback_str(traceback_str: str) -> List[Tuple[str, int, str, str]]:
+         return [
+             (m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
+             for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
+         ]
+
+     if not hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
+         original_format_exception_as_a_whole = VerboseTB.format_exception_as_a_whole
+
+         def custom_format_exception_as_a_whole(
+             self: VerboseTB,
+             etype: Type[BaseException],
+             evalue: Optional[BaseException],
+             etb: Optional[TracebackType],
+             number_of_lines_of_context: int,
+             tb_offset: Optional[int],
+             **kwargs: Any,
+         ) -> List[List[str]]:
+             if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
+                 # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
+                 head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
+                     "(most recent call last)",
+                     "(from remote execution)",
+                 )
+
+                 frames = ListTB._format_list(
+                     self,
+                     parse_traceback_str(remote_err.exc_tb),
+                 )
+                 formatted_exception = self.format_exception(remote_err.exc_type, remote_err.exc_msg)
+
+                 return [[head] + frames + formatted_exception]
+             return original_format_exception_as_a_whole(  # type: ignore[no-any-return]
+                 self,
+                 etype=etype,
+                 evalue=evalue,
+                 etb=etb,
+                 number_of_lines_of_context=number_of_lines_of_context,
+                 tb_offset=tb_offset,
+                 **kwargs,
+             )
+
+         VerboseTB._original_format_exception_as_a_whole = original_format_exception_as_a_whole
+         VerboseTB.format_exception_as_a_whole = _revert_func_wrapper(
+             custom_format_exception_as_a_whole, original_format_exception_as_a_whole, _uninstall_ipython_hook
+         )
+
+     if not hasattr(ListTB, "_original_structured_traceback"):
+         original_structured_traceback = ListTB.structured_traceback
+
+         def structured_traceback(
+             self: ListTB,
+             etype: type,
+             evalue: Optional[BaseException],
+             etb: Optional[TracebackType],
+             tb_offset: Optional[int] = None,
+             **kwargs: Any,
+         ) -> List[str]:
+             if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
+                 tb_list = [
+                     (m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
+                     for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, remote_err.exc_tb or "")
+                 ]
+                 out_list = original_structured_traceback(self, etype, evalue, tb_list, tb_offset, **kwargs)
+                 if out_list:
+                     out_list[0] = out_list[0].replace(
+                         "(most recent call last)",
+                         "(from remote execution)",
+                     )
+                 return cast(List[str], out_list)
+             return original_structured_traceback(  # type: ignore[no-any-return]
+                 self, etype, evalue, etb, tb_offset, **kwargs
+             )
+
+         ListTB._original_structured_traceback = original_structured_traceback
+         ListTB.structured_traceback = _revert_func_wrapper(
+             structured_traceback, original_structured_traceback, _uninstall_ipython_hook
+         )
+
+     return True
+
+
+ def _uninstall_ipython_hook() -> None:
+     """
+     Restore the original IPython traceback formatting if it was modified.
+
+     This is useful when we want to revert to the default behavior after installing a custom hook.
+     """
+     try:
+         from IPython.core.ultratb import ListTB, VerboseTB
+
+         if hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
+             VerboseTB.format_exception_as_a_whole = VerboseTB._original_format_exception_as_a_whole
+             del VerboseTB._original_format_exception_as_a_whole
+
+         if hasattr(ListTB, "_original_structured_traceback"):
+             ListTB.structured_traceback = ListTB._original_structured_traceback
+             del ListTB._original_structured_traceback
+     except ImportError:
+         pass
+
+
+ def install_exception_display_hooks() -> None:
+     if not _install_ipython_hook():
+         _install_sys_excepthook()
+
+
+ # ------ Install the custom traceback hooks by default ------ #
+ install_exception_display_hooks()
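For context, a minimal sketch of how the APIs above surface a remote failure locally. This is illustrative only and not part of the package diff; it assumes the module is importable as snowflake.ml.jobs._utils.interop_utils and uses only the names shown in the hunk.

from snowflake.ml.jobs._utils import interop_utils

# A result dictionary shaped like the one mljob_launcher.py would serialize for a failed job:
# the exception type name, its message, and a string-formatted traceback.
result_dict = {
    "success": False,
    "exc_type": "ValueError",
    "exc_value": "bad hyperparameter",
    "exc_tb": '  File "train.py", line 42, in main\n    raise ValueError("bad hyperparameter")\n',
}

result = interop_utils.ExecutionResult.from_dict(result_dict)
assert not result.success

# load_exception() rebuilt a real ValueError and attached the remote traceback as a RemoteError.
remote_err = interop_utils._retrieve_remote_error_info(result.exception)
print(remote_err.exc_type, remote_err.exc_msg)

# Because install_exception_display_hooks() ran at import time, re-raising this exception
# uncaught prints "Traceback (from remote execution):" followed by the remote frames.
raise result.exception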
snowflake/ml/jobs/_utils/payload_utils.py
@@ -27,6 +27,7 @@ from snowflake.snowpark._internal import code_generation
 
  _SUPPORTED_ARG_TYPES = {str, int, float}
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
+ _ENTRYPOINT_FUNC_NAME = "func"
  _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
  _STARTUP_SCRIPT_CODE = textwrap.dedent(
  f"""
@@ -73,14 +74,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  ##### Ray configuration #####
  shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
- # Check if the instance ip retrieval module exists, which is a prerequisite for multi node jobs
+ # Check if the local get_instance_ip.py script exists
  HELPER_EXISTS=$(
- python3 -c "import snowflake.runtime.utils.get_instance_ip" 2>/dev/null && echo "true" || echo "false"
+ [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
  )
 
  # Configure IP address and logging directory
  if [ "$HELPER_EXISTS" = "true" ]; then
- eth0Ip=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
+ eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
  else
  eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
  fi
@@ -103,7 +104,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
  # Determine if it should be a worker or a head node for batch jobs
  if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
- head_info=$(python3 -m snowflake.runtime.utils.get_instance_ip "$SNOWFLAKE_SERVICE_NAME" --head)
+ head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
  if [ $? -eq 0 ]; then
  # Parse the output using read
  read head_index head_ip <<< "$head_info"
@@ -166,10 +167,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  "--object-store-memory=${{shm_size}}"
  )
 
- # Start Ray on a worker node
- ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
- else
+ # Start Ray on a worker node - run in background
+ ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
+
+ # Start the worker shutdown listener in the background
+ echo "Starting worker shutdown listener..."
+ python worker_shutdown_listener.py
+ WORKER_EXIT_CODE=$?
 
+ echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
+ exit $WORKER_EXIT_CODE
+ else
  # Additional head-specific parameters
  head_params=(
  "--head"
@@ -193,13 +201,39 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
  # Run user's Python entrypoint
  echo Running command: python "$@"
  python "$@"
+
+ # After the user's job completes, signal workers to shut down
+ echo "User job completed. Signaling workers to shut down..."
+ python signal_workers.py --wait-time 15
+ echo "Head node job completed. Exiting."
  fi
  """
  ).strip()
 
 
- def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
- parent = parent.absolute()
+ def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
+ if callable(source):
+ return source
+ elif isinstance(source, Path):
+ # Validate source
+ source = source
+ if not source.exists():
+ raise FileNotFoundError(f"{source} does not exist")
+ return source.absolute()
+ else:
+ raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+
+
+ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
+ if callable(source):
+ # Entrypoint is generated for callable payloads
+ return types.PayloadEntrypoint(
+ file_path=entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH),
+ main_func=_ENTRYPOINT_FUNC_NAME,
+ )
+
+ # Resolve entrypoint path for file-based payloads
+ parent = source.absolute()
  if entrypoint is None:
  if parent.is_file():
  # Infer entrypoint from source
@@ -218,12 +252,23 @@ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
  else:
  # Relative to source dir
  entrypoint = parent.joinpath(entrypoint)
+
+ # Validate resolved entrypoint file
  if not entrypoint.is_file():
  raise FileNotFoundError(
  "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
  f" the source directory (source={parent}, entrypoint={entrypoint})"
  )
- return entrypoint
+ if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+ raise ValueError(
+ "Unsupported entrypoint type:"
+ f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+ )
+
+ return types.PayloadEntrypoint(
+ file_path=entrypoint, # entrypoint is an absolute path at this point
+ main_func=None,
+ )
 
 
  class JobPayload:
@@ -238,40 +283,11 @@ class JobPayload:
  self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
  self.pip_requirements = pip_requirements
 
- def validate(self) -> None:
- if callable(self.source):
- # Any entrypoint value is OK for callable payloads (including None aka default)
- # since we will generate the file from the serialized callable
- pass
- elif isinstance(self.source, Path):
- # Validate source
- source = self.source
- if not source.exists():
- raise FileNotFoundError(f"{source} does not exist")
- source = source.absolute()
-
- # Validate entrypoint
- entrypoint = _resolve_entrypoint(source, self.entrypoint)
- if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
- raise ValueError(
- "Unsupported entrypoint type:"
- f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
- )
-
- # Update fields with normalized values
- self.source = source
- self.entrypoint = entrypoint
- else:
- raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
-
  def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
- # Validate payload
- self.validate()
-
  # Prepare local variables
  stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
- source = self.source
- entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)
+ source = resolve_source(self.source)
+ entrypoint = resolve_entrypoint(source, self.entrypoint)
 
  # Create stage if necessary
  stage_name = stage_path.parts[0].lstrip("@")
@@ -290,11 +306,11 @@ class JobPayload:
  source_code = generate_python_code(source, source_code_display=True)
  _ = session.file.put_stream(
  io.BytesIO(source_code.encode()),
- stage_location=stage_path.joinpath(entrypoint).as_posix(),
+ stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
  auto_compress=False,
  overwrite=True,
  )
- source = entrypoint.parent
+ source = Path(entrypoint.file_path.parent)
  elif source.is_dir():
  # Manually traverse the directory and upload each file, since Snowflake PUT
  # can't handle directories. Reduce the number of PUT operations by using
@@ -337,12 +353,30 @@ class JobPayload:
  overwrite=False, # FIXME
  )
 
+ # Upload system scripts
+ scripts_dir = Path(__file__).parent.joinpath("scripts")
+ for script_file in scripts_dir.glob("*"):
+ if script_file.is_file():
+ session.file.put(
+ script_file.as_posix(),
+ stage_path.as_posix(),
+ overwrite=True,
+ auto_compress=False,
+ )
+
+ python_entrypoint: List[Union[str, PurePath]] = [
+ PurePath("mljob_launcher.py"),
+ entrypoint.file_path.relative_to(source),
+ ]
+ if entrypoint.main_func:
+ python_entrypoint += ["--script_main_func", entrypoint.main_func]
+
  return types.UploadedPayload(
  stage_path=stage_path,
  entrypoint=[
  "bash",
  _STARTUP_SCRIPT_PATH,
- entrypoint.relative_to(source),
+ *python_entrypoint,
  ],
  )
 
@@ -471,12 +505,11 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
  # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
  source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
 
- func_name = "func"
  func_code = f"""
  {source_code_comment}
 
  import pickle
- {func_name} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+ {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
  """
 
  arg_dict_name = "kwargs"
@@ -487,6 +520,7 @@ import pickle
 
  return f"""
  ### Version guard to check compatibility across Python versions ###
+ import os
  import sys
  import warnings
 
@@ -508,5 +542,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
  if __name__ == '__main__':
  {textwrap.indent(param_code, ' ')}
 
- {func_name}(**{arg_dict_name})
+ __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
  """
snowflake/ml/jobs/_utils/scripts/constants.py (new file)
@@ -0,0 +1,4 @@
+ # Constants defining the shutdown signal actor configuration.
+ SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
+ SHUTDOWN_ACTOR_NAMESPACE = "default"
+ SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
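scripts/constants.py only defines the shutdown-signal actor's name, namespace, and RPC timeout; the signal_workers.py and worker_shutdown_listener.py scripts that use them appear in the file summary but are not shown in this diff. A hypothetical sketch of how a named, detached Ray actor with these settings could coordinate the head-to-worker shutdown handshake (this is not the package's actual implementation):

import ray

# Assumes the scripts run from the same uploaded directory as constants.py.
from constants import SHUTDOWN_ACTOR_NAME, SHUTDOWN_ACTOR_NAMESPACE, SHUTDOWN_RPC_TIMEOUT_SECONDS

@ray.remote
class ShutdownSignal:
    """Tiny flag actor that workers poll to learn the head node has finished."""

    def __init__(self) -> None:
        self._requested = False

    def request_shutdown(self) -> None:
        self._requested = True

    def should_shutdown(self) -> bool:
        return self._requested

# Head node side (e.g. signal_workers.py): publish the named, detached actor and flip the flag.
signal = ShutdownSignal.options(
    name=SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE, lifetime="detached"
).remote()
ray.get(signal.request_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)

# Worker node side (e.g. worker_shutdown_listener.py): look up the actor and poll the flag.
handle = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
should_exit = ray.get(handle.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)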