snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff compares publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +44 -10
- snowflake/ml/_internal/platform_capabilities.py +39 -3
- snowflake/ml/data/data_connector.py +25 -0
- snowflake/ml/dataset/dataset_reader.py +5 -1
- snowflake/ml/jobs/_utils/constants.py +3 -5
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +81 -47
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +27 -8
- snowflake/ml/jobs/_utils/types.py +6 -0
- snowflake/ml/jobs/decorators.py +10 -6
- snowflake/ml/jobs/job.py +145 -23
- snowflake/ml/jobs/manager.py +79 -12
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +57 -39
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
- snowflake/ml/model/_client/sql/service.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +29 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
- snowflake/ml/model/_packager/model_env/model_env.py +8 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/registry/_manager/model_manager.py +20 -1
- snowflake/ml/registry/registry.py +46 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/snowflake/ml/jobs/_utils/interop_utils.py
@@ -0,0 +1,442 @@
+import builtins
+import functools
+import importlib
+import json
+import os
+import pickle
+import re
+import sys
+import traceback
+from collections import namedtuple
+from dataclasses import dataclass
+from types import TracebackType
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
+
+from snowflake import snowpark
+from snowflake.snowpark import exceptions as sp_exceptions
+
+_TRACEBACK_ENTRY_PATTERN = re.compile(
+    r'File "(?P<filename>[^"]+)", line (?P<lineno>\d+), in (?P<name>[^\n]+)(?:\n(?!^\s*File)^\s*(?P<line>[^\n]+))?\n',
+    flags=re.MULTILINE,
+)
+_REMOTE_ERROR_ATTR_NAME = "_remote_error"
+
+RemoteError = namedtuple("RemoteError", ["exc_type", "exc_msg", "exc_tb"])
+
+
+@dataclass(frozen=True)
+class ExecutionResult:
+    result: Any = None
+    exception: Optional[BaseException] = None
+
+    @property
+    def success(self) -> bool:
+        return self.exception is None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return the serializable dictionary."""
+        if isinstance(self.exception, BaseException):
+            exc_type = type(self.exception)
+            return {
+                "success": False,
+                "exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
+                "exc_value": self.exception,
+                "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
+            }
+        return {
+            "success": True,
+            "result_type": type(self.result).__qualname__,
+            "result": self.result,
+        }
+
+    @classmethod
+    def from_dict(cls, result_dict: Dict[str, Any]) -> "ExecutionResult":
+        if not isinstance(result_dict.get("success"), bool):
+            raise ValueError("Invalid result dictionary")
+
+        if result_dict["success"]:
+            # Load successful result
+            return cls(result=result_dict.get("result"))
+
+        # Load exception
+        exc_type = result_dict.get("exc_type", "RuntimeError")
+        exc_value = result_dict.get("exc_value", "Unknown error")
+        exc_tb = result_dict.get("exc_tb", "")
+        return cls(exception=load_exception(exc_type, exc_value, exc_tb))
+
+
+def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult:
+    """
+    Fetch the serialized result from the specified path.
+
+    Args:
+        session: Snowpark Session to use for file operations.
+        result_path: The path to the serialized result file.
+
+    Returns:
+        A dictionary containing the execution result if available, None otherwise.
+    """
+    try:
+        # TODO: Check if file exists
+        with session.file.get_stream(result_path) as result_stream:
+            return ExecutionResult.from_dict(pickle.load(result_stream))
+    except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
+        # Fall back to JSON result if loading pickled result fails for any reason
+        result_json_path = os.path.splitext(result_path)[0] + ".json"
+        with session.file.get_stream(result_json_path) as result_stream:
+            return ExecutionResult.from_dict(json.load(result_stream))
+
+
+def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
+    """
+    Create an exception with a string-formatted traceback.
+
+    When this exception is raised and not caught, it will display the original traceback.
+    When caught, it behaves like a regular exception without showing the traceback.
+
+    Args:
+        exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
+        exc_value: The deserialized exception value or exception string (i.e. message)
+        exc_tb: String representation of the traceback
+
+    Returns:
+        An exception object with the original traceback information
+
+    # noqa: DAR401
+    """
+    if isinstance(exc_value, Exception):
+        exception = exc_value
+    else:
+        # Try to load the original exception type if possible
+        try:
+            # First check built-in exceptions
+            exc_type = getattr(builtins, exc_type_name, None)
+            if exc_type is None and "." in exc_type_name:
+                # Try to import from module path if it's a qualified name
+                module_path, class_name = exc_type_name.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                exc_type = getattr(module, class_name)
+            if exc_type is None or not issubclass(exc_type, Exception):
+                raise TypeError(f"{exc_type_name} is not a known exception type")
+            # Create the exception instance
+            exception = exc_type(exc_value)
+        except (ImportError, AttributeError, TypeError):
+            # Fall back to a generic exception
+            exception = RuntimeError(
+                f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}"
+            )
+
+    # Attach the traceback information to the exception
+    return _attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
+
+
+def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceback_str: str) -> Exception:
+    """
+    Attach a string-formatted traceback to an exception.
+
+    When the exception is raised and not caught, it will display the original traceback.
+    When caught, it behaves like a regular exception without showing the traceback.
+
+    Args:
+        ex: The exception object to modify
+        exc_type: The original exception type name
+        exc_msg: The original exception message
+        traceback_str: String representation of the traceback
+
+    Returns:
+        An exception object with the original traceback information
+    """
+    # Store the traceback information
+    exc_type = exc_type.rsplit(".", 1)[-1]  # Remove module path
+    setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteError(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str))
+    return ex
+
+
+def _retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteError]:
+    """
+    Retrieve the string-formatted traceback from an exception if it exists.
+
+    Args:
+        ex: The exception to retrieve the traceback from
+
+    Returns:
+        The remote error tuple if it exists, None otherwise
+    """
+    if not ex:
+        return None
+    return getattr(ex, _REMOTE_ERROR_ATTR_NAME, None)
+
+
+# ###############################################################################
+# ------------------------------- !!! NOTE !!! -------------------------------- #
+# ###############################################################################
+# Job execution results (including uncaught exceptions) are serialized to file(s)
+# in mljob_launcher.py. When the job is executed remotely, the serialized results
+# are fetched and deserialized in the local environment. If the result contains
+# an exception the original traceback is reconstructed and displayed to the user.
+#
+# It's currently impossible to recreate the original traceback object, so the
+# following overrides are necessary to attach and display the deserialized
+# traceback during exception handling.
+#
+# The following code implements the necessary overrides including sys.excepthook
+# modifications and IPython traceback formatting. The hooks are applied on init
+# and will be active for the duration of the process. The hooks are designed to
+# self-uninstall in the event of an error in case of future compatibility issues.
+# ###############################################################################
+
+
+def _revert_func_wrapper(
+    patched_func: Callable[..., Any],
+    original_func: Callable[..., Any],
+    uninstall_func: Callable[[], None],
+) -> Callable[..., Any]:
+    """
+    Create a wrapper function that uninstalls the original function if an error occurs during execution.
+
+    This wrapper provides a fallback mechanism where if the patched function fails, it will:
+    1. Uninstall the patched function using the provided uninstall_func, reverting back to using the original function
+    2. Re-execute the current call using the original (unpatched) function with the same arguments
+
+    Args:
+        patched_func: The patched function to call.
+        original_func: The original function to call if patched_func fails.
+        uninstall_func: The function to call to uninstall the patched function.
+
+    Returns:
+        A wrapped function that calls patched_func and uninstalls on failure.
+    """
+
+    @functools.wraps(patched_func)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+        try:
+            return patched_func(*args, **kwargs)
+        except Exception:
+            # Uninstall and revert to original on failure
+            uninstall_func()
+            return original_func(*args, **kwargs)
+
+    return wrapped
+
+
+def _install_sys_excepthook() -> None:
+    """
+    Install a custom sys.excepthook to handle remote exception tracebacks.
+
+    sys.excepthook is the global hook that Python calls when an unhandled exception occurs.
+    By default it prints the exception type, message and traceback to stderr.
+
+    We override sys.excepthook to intercept exceptions that contain our special RemoteError
+    attribute. These exceptions come from deserialized remote execution results and contain
+    the original traceback information from where they occurred.
+
+    When such an exception is detected, we format and display the original remote traceback
+    instead of the local one, which provides better debugging context by showing where the
+    error actually happened during remote execution.
+
+    The custom hook maintains proper exception chaining for both __cause__ (from raise from)
+    and __context__ (from implicit exception chaining).
+    """
+    # Attach the custom excepthook for standard Python scripts if not already attached
+    if not hasattr(sys, "_original_excepthook"):
+        original_excepthook = sys.excepthook
+
+        def custom_excepthook(
+            exc_type: Type[BaseException],
+            exc_value: BaseException,
+            exc_tb: Optional[TracebackType],
+            *,
+            seen_exc_ids: Optional[Set[int]] = None,
+        ) -> None:
+            if seen_exc_ids is None:
+                seen_exc_ids = set()
+            seen_exc_ids.add(id(exc_value))
+
+            cause = getattr(exc_value, "__cause__", None)
+            context = getattr(exc_value, "__context__", None)
+            if cause:
+                # Handle cause-chained exceptions
+                custom_excepthook(type(cause), cause, cause.__traceback__, seen_exc_ids=seen_exc_ids)
+                print(  # noqa: T201
+                    "\nThe above exception was the direct cause of the following exception:\n", file=sys.stderr
+                )
+            elif context and not getattr(exc_value, "__suppress_context__", False):
+                # Handle context-chained exceptions
+                # Only process context if it's different from cause to avoid double printing
+                custom_excepthook(type(context), context, context.__traceback__, seen_exc_ids=seen_exc_ids)
+                print(  # noqa: T201
+                    "\nDuring handling of the above exception, another exception occurred:\n", file=sys.stderr
+                )
+
+            if (remote_err := _retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteError):
+                # Display stored traceback for deserialized exceptions
+                print("Traceback (from remote execution):", file=sys.stderr)  # noqa: T201
+                print(remote_err.exc_tb, end="", file=sys.stderr)  # noqa: T201
+                print(f"{remote_err.exc_type}: {remote_err.exc_msg}", file=sys.stderr)  # noqa: T201
+            else:
+                # Fall back to the original excepthook
+                traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stderr, chain=False)
+
+        sys._original_excepthook = original_excepthook  # type: ignore[attr-defined]
+        sys.excepthook = _revert_func_wrapper(custom_excepthook, original_excepthook, _uninstall_sys_excepthook)
+
+
+def _uninstall_sys_excepthook() -> None:
+    """
+    Restore the original excepthook for the current process.
+
+    This is useful when we want to revert to the default behavior after installing a custom excepthook.
+    """
+    if hasattr(sys, "_original_excepthook"):
+        sys.excepthook = sys._original_excepthook
+        del sys._original_excepthook
+
+
+def _install_ipython_hook() -> bool:
+    """Install IPython-specific exception handling hook to improve remote error reporting.
+
+    This function enhances IPython's error formatting capabilities by intercepting and customizing
+    how remote execution errors are displayed. It modifies two key IPython traceback formatters:
+
+    1. VerboseTB.format_exception_as_a_whole: Customizes the full traceback formatting for remote
+       errors by:
+       - Adding a "(from remote execution)" header instead of "(most recent call last)"
+       - Properly formatting the remote traceback entries
+       - Maintaining original behavior for non-remote errors
+
+    2. ListTB.structured_traceback: Modifies the structured traceback output by:
+       - Parsing and formatting remote tracebacks appropriately
+       - Adding remote execution context to the output
+       - Preserving original functionality for local errors
+
+    The modifications are needed because IPython's default error handling doesn't properly display
+    remote execution errors that occur in Snowpark/Snowflake operations. The custom formatters
+    ensure that error messages from remote executions are properly captured, formatted and displayed
+    with the correct context and traceback information.
+
+    Returns:
+        bool: True if IPython hooks were successfully installed, False if IPython is not available
+            or not in an IPython environment.
+
+    Note:
+        This function maintains the ability to revert changes through _uninstall_ipython_hook by
+        storing original implementations before applying modifications.
+    """
+    try:
+        from IPython.core.getipython import get_ipython
+        from IPython.core.ultratb import ListTB, VerboseTB
+
+        if get_ipython() is None:
+            return False
+    except ImportError:
+        return False
+
+    def parse_traceback_str(traceback_str: str) -> List[Tuple[str, int, str, str]]:
+        return [
+            (m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
+            for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
+        ]
+
+    if not hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
+        original_format_exception_as_a_whole = VerboseTB.format_exception_as_a_whole
+
+        def custom_format_exception_as_a_whole(
+            self: VerboseTB,
+            etype: Type[BaseException],
+            evalue: Optional[BaseException],
+            etb: Optional[TracebackType],
+            number_of_lines_of_context: int,
+            tb_offset: Optional[int],
+            **kwargs: Any,
+        ) -> List[List[str]]:
+            if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
+                # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
+                head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
+                    "(most recent call last)",
+                    "(from remote execution)",
+                )
+
+                frames = ListTB._format_list(
+                    self,
+                    parse_traceback_str(remote_err.exc_tb),
+                )
+                formatted_exception = self.format_exception(remote_err.exc_type, remote_err.exc_msg)
+
+                return [[head] + frames + formatted_exception]
+            return original_format_exception_as_a_whole(  # type: ignore[no-any-return]
+                self,
+                etype=etype,
+                evalue=evalue,
+                etb=etb,
+                number_of_lines_of_context=number_of_lines_of_context,
+                tb_offset=tb_offset,
+                **kwargs,
+            )
+
+        VerboseTB._original_format_exception_as_a_whole = original_format_exception_as_a_whole
+        VerboseTB.format_exception_as_a_whole = _revert_func_wrapper(
+            custom_format_exception_as_a_whole, original_format_exception_as_a_whole, _uninstall_ipython_hook
+        )
+
+    if not hasattr(ListTB, "_original_structured_traceback"):
+        original_structured_traceback = ListTB.structured_traceback
+
+        def structured_traceback(
+            self: ListTB,
+            etype: type,
+            evalue: Optional[BaseException],
+            etb: Optional[TracebackType],
+            tb_offset: Optional[int] = None,
+            **kwargs: Any,
+        ) -> List[str]:
+            if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
+                tb_list = [
+                    (m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
+                    for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, remote_err.exc_tb or "")
+                ]
+                out_list = original_structured_traceback(self, etype, evalue, tb_list, tb_offset, **kwargs)
+                if out_list:
+                    out_list[0] = out_list[0].replace(
+                        "(most recent call last)",
+                        "(from remote execution)",
+                    )
+                return cast(List[str], out_list)
+            return original_structured_traceback(  # type: ignore[no-any-return]
+                self, etype, evalue, etb, tb_offset, **kwargs
+            )
+
+        ListTB._original_structured_traceback = original_structured_traceback
+        ListTB.structured_traceback = _revert_func_wrapper(
+            structured_traceback, original_structured_traceback, _uninstall_ipython_hook
+        )
+
+    return True
+
+
+def _uninstall_ipython_hook() -> None:
+    """
+    Restore the original IPython traceback formatting if it was modified.
+
+    This is useful when we want to revert to the default behavior after installing a custom hook.
+    """
+    try:
+        from IPython.core.ultratb import ListTB, VerboseTB
+
+        if hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
+            VerboseTB.format_exception_as_a_whole = VerboseTB._original_format_exception_as_a_whole
+            del VerboseTB._original_format_exception_as_a_whole
+
+        if hasattr(ListTB, "_original_structured_traceback"):
+            ListTB.structured_traceback = ListTB._original_structured_traceback
+            del ListTB._original_structured_traceback
+    except ImportError:
+        pass
+
+
+def install_exception_display_hooks() -> None:
+    if not _install_ipython_hook():
+        _install_sys_excepthook()
+
+
+# ------ Install the custom traceback hooks by default ------ #
+install_exception_display_hooks()
--- a/snowflake/ml/jobs/_utils/payload_utils.py
+++ b/snowflake/ml/jobs/_utils/payload_utils.py
@@ -27,6 +27,7 @@ from snowflake.snowpark._internal import code_generation
 
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
+_ENTRYPOINT_FUNC_NAME = "func"
 _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
 _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
@@ -73,14 +74,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ##### Ray configuration #####
 shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
 
-# Check if the
+# Check if the local get_instance_ip.py script exists
 HELPER_EXISTS=$(
-
+    [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
 )
 
 # Configure IP address and logging directory
 if [ "$HELPER_EXISTS" = "true" ]; then
-    eth0Ip=$(python3
+    eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
 else
     eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
 fi
@@ -103,7 +104,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
 # Determine if it should be a worker or a head node for batch jobs
 if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-    head_info=$(python3
+    head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
     if [ $? -eq 0 ]; then
         # Parse the output using read
        read head_index head_ip <<< "$head_info"
@@ -166,10 +167,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         "--object-store-memory=${{shm_size}}"
     )
 
-    # Start Ray on a worker node
-    ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
-
+    # Start Ray on a worker node - run in background
+    ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
+
+    # Start the worker shutdown listener in the background
+    echo "Starting worker shutdown listener..."
+    python worker_shutdown_listener.py
+    WORKER_EXIT_CODE=$?
 
+    echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
+    exit $WORKER_EXIT_CODE
+else
     # Additional head-specific parameters
     head_params=(
         "--head"
@@ -193,13 +201,39 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     # Run user's Python entrypoint
     echo Running command: python "$@"
     python "$@"
+
+    # After the user's job completes, signal workers to shut down
+    echo "User job completed. Signaling workers to shut down..."
+    python signal_workers.py --wait-time 15
+    echo "Head node job completed. Exiting."
 fi
 """
 ).strip()
 
 
-def
-
+def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
+    if callable(source):
+        return source
+    elif isinstance(source, Path):
+        # Validate source
+        source = source
+        if not source.exists():
+            raise FileNotFoundError(f"{source} does not exist")
+        return source.absolute()
+    else:
+        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+
+
+def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
+    if callable(source):
+        # Entrypoint is generated for callable payloads
+        return types.PayloadEntrypoint(
+            file_path=entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH),
+            main_func=_ENTRYPOINT_FUNC_NAME,
+        )
+
+    # Resolve entrypoint path for file-based payloads
+    parent = source.absolute()
     if entrypoint is None:
         if parent.is_file():
             # Infer entrypoint from source
@@ -218,12 +252,23 @@ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
         else:
             # Relative to source dir
             entrypoint = parent.joinpath(entrypoint)
+
+    # Validate resolved entrypoint file
     if not entrypoint.is_file():
         raise FileNotFoundError(
             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
             f" the source directory (source={parent}, entrypoint={entrypoint})"
         )
-
+    if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
+        raise ValueError(
+            "Unsupported entrypoint type:"
+            f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
+        )
+
+    return types.PayloadEntrypoint(
+        file_path=entrypoint,  # entrypoint is an absolute path at this point
+        main_func=None,
+    )
 
 
 class JobPayload:
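
The module-level resolve_source and resolve_entrypoint helpers above replace the JobPayload.validate method that the next hunk removes, moving validation to upload time. A rough usage sketch with hypothetical paths, assuming the module path from the file list above:

    # Hypothetical sketch: resolve a directory payload the way upload() now does.
    import tempfile
    from pathlib import Path

    from snowflake.ml.jobs._utils.payload_utils import resolve_entrypoint, resolve_source

    with tempfile.TemporaryDirectory() as tmp:
        job_dir = Path(tmp) / "my_job"
        job_dir.mkdir()
        (job_dir / "train.py").write_text("print('hello')\n")

        source = resolve_source(job_dir)                      # validated, absolute path
        entry = resolve_entrypoint(source, Path("train.py"))  # resolved against source
        print(entry.file_path)  # .../my_job/train.py
        print(entry.main_func)  # None for file-based payloads
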
@@ -238,40 +283,11 @@ class JobPayload:
         self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
 
-    def validate(self) -> None:
-        if callable(self.source):
-            # Any entrypoint value is OK for callable payloads (including None aka default)
-            # since we will generate the file from the serialized callable
-            pass
-        elif isinstance(self.source, Path):
-            # Validate source
-            source = self.source
-            if not source.exists():
-                raise FileNotFoundError(f"{source} does not exist")
-            source = source.absolute()
-
-            # Validate entrypoint
-            entrypoint = _resolve_entrypoint(source, self.entrypoint)
-            if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
-                raise ValueError(
-                    "Unsupported entrypoint type:"
-                    f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
-                )
-
-            # Update fields with normalized values
-            self.source = source
-            self.entrypoint = entrypoint
-        else:
-            raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
-
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
-        # Validate payload
-        self.validate()
-
         # Prepare local variables
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
-        source = self.source
-        entrypoint = self.entrypoint
+        source = resolve_source(self.source)
+        entrypoint = resolve_entrypoint(source, self.entrypoint)
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
@@ -290,11 +306,11 @@ class JobPayload:
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
-                stage_location=stage_path.joinpath(entrypoint).as_posix(),
+                stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
-            source = entrypoint.parent
+            source = Path(entrypoint.file_path.parent)
         elif source.is_dir():
             # Manually traverse the directory and upload each file, since Snowflake PUT
             # can't handle directories. Reduce the number of PUT operations by using
@@ -337,12 +353,30 @@ class JobPayload:
                 overwrite=False,  # FIXME
             )
 
+        # Upload system scripts
+        scripts_dir = Path(__file__).parent.joinpath("scripts")
+        for script_file in scripts_dir.glob("*"):
+            if script_file.is_file():
+                session.file.put(
+                    script_file.as_posix(),
+                    stage_path.as_posix(),
+                    overwrite=True,
+                    auto_compress=False,
+                )
+
+        python_entrypoint: List[Union[str, PurePath]] = [
+            PurePath("mljob_launcher.py"),
+            entrypoint.file_path.relative_to(source),
+        ]
+        if entrypoint.main_func:
+            python_entrypoint += ["--script_main_func", entrypoint.main_func]
+
         return types.UploadedPayload(
             stage_path=stage_path,
             entrypoint=[
                 "bash",
                 _STARTUP_SCRIPT_PATH,
-
+                *python_entrypoint,
             ],
         )
 
@@ -471,12 +505,11 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
 
-    func_name = "func"
     func_code = f"""
 {source_code_comment}
 
 import pickle
-{
+{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
 """
 
     arg_dict_name = "kwargs"
@@ -487,6 +520,7 @@ import pickle
 
     return f"""
 ### Version guard to check compatibility across Python versions ###
+import os
 import sys
 import warnings
 
@@ -508,5 +542,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
 if __name__ == '__main__':
 {textwrap.indent(param_code, '    ')}
 
-{
+__return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
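
For callable payloads, the generated entrypoint now binds the deserialized function to the shared _ENTRYPOINT_FUNC_NAME ("func") instead of a locally defined name, and it relies on a pickle-to-hex round trip. A standalone sketch of that mechanism, with plain pickle standing in for the library's _serialize_callable helper (whose internals are not shown in this diff) and an invented function for illustration:

    import pickle

    def train(epochs: int = 1) -> str:
        return f"trained for {epochs} epochs"

    # Roughly what the generated entrypoint embeds: the callable as a hex string.
    payload_hex = pickle.dumps(train).hex()

    # What the generated script does at run time: rebuild the callable and invoke it.
    func = pickle.loads(bytes.fromhex(payload_hex))
    kwargs = {"epochs": 3}
    __return__ = func(**kwargs)
    print(__return__)  # trained for 3 epochs
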