snowflake-ml-python 1.8.5__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/payload_utils.py +83 -35
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +23 -1
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +6 -7
- snowflake/ml/jobs/job.py +24 -9
- snowflake/ml/jobs/manager.py +102 -19
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +19 -4
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +14 -5
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +33 -30
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/telemetry.py
CHANGED
@@ -411,16 +411,13 @@ def send_custom_usage(
     **kwargs: Any,
 ) -> None:
     conn = _get_snowflake_connection()
-    if conn is None:
-        raise ValueError(
-            """Snowflake connection is required to send custom telemetry. This means there
-            must be at least one active session, or that telemetry is being sent from within an SPCS service."""
-        )
 
-    client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
-    common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
-    data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
-    client._send(msg=data)
+    # Send telemetry if Snowflake connection is available.
+    if conn is not None:
+        client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
+        common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
+        data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
+        client._send(msg=data)
 
 
 def send_api_usage_telemetry(
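The net effect of this hunk: send_custom_usage() no longer raises when no Snowflake connection can be resolved; it silently skips sending instead. A minimal runnable sketch of the before/after behavior, using a stand-in client (StubClient and the explicit conn argument are illustrative, not part of snowflake-ml-python):

from typing import Any, Optional

class StubClient:
    """Stand-in for _SourceTelemetryClient; just prints the message."""
    def __init__(self, conn: Any) -> None:
        self.conn = conn
    def send(self, msg: dict) -> None:
        print(f"sent: {msg}")

def send_custom_usage(conn: Optional[Any], data: dict, **kwargs: Any) -> None:
    # 1.8.5 raised ValueError when conn was None; 1.8.6 skips sending instead.
    if conn is not None:
        StubClient(conn).send({**data, **kwargs})

send_custom_usage(None, {"metric": 1})      # no-op under the new behavior
send_custom_usage(object(), {"metric": 1})  # prints: sent: {'metric': 1}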
snowflake/ml/_internal/utils/connection_params.py
ADDED
@@ -0,0 +1,196 @@
+import configparser
+import os
+from typing import Optional, Union
+
+from absl import logging
+from cryptography.hazmat import backends
+from cryptography.hazmat.primitives import serialization
+
+_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
+
+
+def _read_token(token_file: str = "") -> str:
+    """
+    Reads token from environment or file provided.
+
+    First tries to read the token from environment variable
+    (`SNOWFLAKE_TOKEN`) followed by the token file.
+    Both the options are tried out in SnowServices.
+
+    Args:
+        token_file: File from which token needs to be read. Optional.
+
+    Returns:
+        the token.
+    """
+    token = os.getenv("SNOWFLAKE_TOKEN", "")
+    if token:
+        return token
+    if token_file and os.path.exists(token_file):
+        with open(token_file) as f:
+            token = f.read()
+    return token
+
+
+_ENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN ENCRYPTED PRIVATE KEY-----"
+_UNENCRYPTED_PKCS8_PK_HEADER = b"-----BEGIN PRIVATE KEY-----"
+
+
+def _load_pem_to_der(private_key_path: str) -> bytes:
+    """Given a private key file path (in PEM format), decode key data into DER format."""
+    with open(private_key_path, "rb") as f:
+        private_key_pem = f.read()
+    private_key_passphrase: Optional[str] = os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", None)
+
+    # Only PKCS#8 format key will be accepted. However, openssl
+    # transparently handle PKCS#8 and PKCS#1 format (by some fallback
+    # logic) and their is no function to distinguish between them. By
+    # reading openssl source code, apparently they also relies on header
+    # to determine if give bytes is PKCS#8 format or not
+    if not private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and not private_key_pem.startswith(
+        _UNENCRYPTED_PKCS8_PK_HEADER
+    ):
+        raise Exception("Private key provided is not in PKCS#8 format. Please use correct format.")
+
+    if private_key_pem.startswith(_ENCRYPTED_PKCS8_PK_HEADER) and private_key_passphrase is None:
+        raise Exception(
+            "Private key is encrypted but passphrase could not be found. "
+            "Please set SNOWFLAKE_PRIVATE_KEY_PASSPHRASE env variable."
+        )
+
+    if private_key_pem.startswith(_UNENCRYPTED_PKCS8_PK_HEADER):
+        private_key_passphrase = None
+
+    private_key = serialization.load_pem_private_key(
+        private_key_pem,
+        str.encode(private_key_passphrase) if private_key_passphrase is not None else private_key_passphrase,
+        backends.default_backend(),
+    )
+
+    return private_key.private_bytes(
+        encoding=serialization.Encoding.DER,
+        format=serialization.PrivateFormat.PKCS8,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+
+
+def _connection_properties_from_env() -> dict[str, str]:
+    """Returns a dict with all possible login related env variables."""
+    sf_conn_prop = {
+        # Mandatory fields
+        "account": os.environ["SNOWFLAKE_ACCOUNT"],
+        "database": os.environ["SNOWFLAKE_DATABASE"],
+        # With a default value
+        "token_file": os.getenv("SNOWFLAKE_TOKEN_FILE", "/snowflake/session/token"),
+        "ssl": os.getenv("SNOWFLAKE_SSL", "on"),
+        "protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"),
+    }
+    # With empty default value
+    for key, env_var in {
+        "user": "SNOWFLAKE_USER",
+        "authenticator": "SNOWFLAKE_AUTHENTICATOR",
+        "password": "SNOWFLAKE_PASSWORD",
+        "host": "SNOWFLAKE_HOST",
+        "port": "SNOWFLAKE_PORT",
+        "schema": "SNOWFLAKE_SCHEMA",
+        "warehouse": "SNOWFLAKE_WAREHOUSE",
+        "private_key_path": "SNOWFLAKE_PRIVATE_KEY_PATH",
+    }.items():
+        value = os.getenv(env_var, "")
+        if value:
+            sf_conn_prop[key] = value
+    return sf_conn_prop
+
+
+def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
+    """Loads the dictionary from snowsql config file."""
+    snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
+    if not os.path.exists(snowsql_config_file):
+        logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
+        raise Exception("Snowflake SnowSQL config not found.")
+
+    config = configparser.ConfigParser(inline_comment_prefixes="#")
+
+    snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
+    if snowflake_connection_name is not None:
+        connection_name = snowflake_connection_name
+
+    if connection_name:
+        if not connection_name.startswith("connections."):
+            connection_name = "connections." + connection_name
+    else:
+        # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
+        connection_name = "connections"
+
+    logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
+    config.read(snowsql_config_file)
+    conn_params = dict(config[connection_name])
+    # Remap names to appropriate args in Python Connector API
+    # Note: "dbname" should become "database"
+    conn_params = {k.replace("name", ""): v.strip('"') for k, v in conn_params.items()}
+    if "db" in conn_params:
+        conn_params["database"] = conn_params["db"]
+        del conn_params["db"]
+    return conn_params
+
+
+def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
+    """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
+
+    NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
+    1. If SNOWFLAKE_TOKEN is defined in the environment, it will be used.
+    2. If SNOWFLAKE_TOKEN_FILE is defined in the environment and file matching the value found, content of the file
+       will be used.
+
+    If token is found, username, password will be reset and 'authenticator' will be set to 'oauth'.
+
+    Python Connector:
+    >> ctx = snowflake.connector.connect(**(SnowflakeLoginOptions()))
+
+    Snowpark Session:
+    >> session = Session.builder.configs(SnowflakeLoginOptions()).create()
+
+    Usage Note:
+      Ideally one should have a snowsql config file. Read more here:
+      https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
+
+      If snowsql config file does not exist, it tries auth from env variables.
+
+    Args:
+        connection_name: Name of the connection to look for inside the config file. If environment variable
+            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
+        login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
+
+    Returns:
+        A dict with connection parameters.
+
+    Raises:
+        Exception: if none of config file and environment variable are present.
+    """
+    conn_prop: dict[str, Union[str, bytes]] = {}
+    login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
+    # If login file exists, use this exclusively.
+    if os.path.exists(login_file):
+        conn_prop = {**(_load_from_snowsql_config_file(connection_name, login_file))}
+    else:
+        # If environment exists for SNOWFLAKE_ACCOUNT, assume everything
+        # comes from environment. Mixing it not allowed.
+        account = os.getenv("SNOWFLAKE_ACCOUNT", "")
+        if account:
+            conn_prop = {**_connection_properties_from_env()}
+        else:
+            raise Exception("Snowflake credential is neither set in env nor a login file was provided.")
+
+    # Token, if specified, is always side-loaded in all cases.
+    token = _read_token(str(conn_prop["token_file"]) if "token_file" in conn_prop else "")
+    if token:
+        conn_prop["token"] = token
+        if "authenticator" not in conn_prop or conn_prop["authenticator"]:
+            conn_prop["authenticator"] = "oauth"
+    elif "private_key_path" in conn_prop and "private_key" not in conn_prop:
+        conn_prop["private_key"] = _load_pem_to_der(str(conn_prop["private_key_path"]))
+
+    if "ssl" in conn_prop and conn_prop["ssl"].lower() == "off":
+        conn_prop["protocol"] = "http"
+
+    return conn_prop
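For orientation, a hedged usage sketch of the environment-variable path (taken when no snowsql config file exists). All values are placeholders; the import path matches the new file's location in this diff:

import os

# Placeholders; SNOWFLAKE_ACCOUNT and SNOWFLAKE_DATABASE are the mandatory fields.
os.environ["SNOWFLAKE_ACCOUNT"] = "my_account"
os.environ["SNOWFLAKE_DATABASE"] = "my_database"
os.environ["SNOWFLAKE_USER"] = "my_user"
os.environ["SNOWFLAKE_PASSWORD"] = "***"

from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions

# Falls back to env vars only if ~/.snowsql/config is absent on this machine.
params = SnowflakeLoginOptions()
# params also carries the defaults (token_file, ssl="on", protocol="https") and
# can be splatted into snowflake.connector.connect(**params).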
snowflake/ml/jobs/_utils/constants.py
CHANGED
@@ -6,6 +6,7 @@ DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
+RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
 STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -14,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.2
+DEFAULT_IMAGE_TAG = "1.4.2"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -43,7 +44,7 @@ ENABLE_HEALTH_CHECKS = "false"
 
 # Job status polling constants
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
-JOB_POLL_MAX_DELAY_SECONDS =
+JOB_POLL_MAX_DELAY_SECONDS = 30
 
 # Magic attributes
 IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
snowflake/ml/jobs/_utils/function_payload_utils.py
ADDED
@@ -0,0 +1,43 @@
+import inspect
+from typing import Any, Callable, Optional
+
+from snowflake import snowpark
+from snowflake.snowpark import context as sp_context
+
+
+class FunctionPayload:
+    def __init__(
+        self,
+        func: Callable[..., Any],
+        session: Optional[snowpark.Session] = None,
+        session_argument: str = "",
+        *args: Any,
+        **kwargs: Any
+    ) -> None:
+        self.function = func
+        self.args = args
+        self.kwargs = kwargs
+        self._session = session
+        self._session_argument = session_argument
+
+    @property
+    def session(self) -> Optional[snowpark.Session]:
+        return self._session
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude session."""
+        state = self.__dict__.copy()
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        self.__dict__.update(state)
+        self._session = sp_context.get_active_session()
+
+    def __call__(self) -> Any:
+        sig = inspect.signature(self.function)
+        bound = sig.bind_partial(*self.args, **self.kwargs)
+        bound.arguments[self._session_argument] = self._session
+
+        return self.function(*bound.args, **bound.kwargs)
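The __getstate__/__setstate__ pair is what lets a FunctionPayload travel to the job container: the live session is dropped before pickling and re-resolved from the active-session context after unpickling. A toy reproduction of that contract (the Payload class below is illustrative; constructing a real snowpark.Session needs credentials):

import pickle

class Payload:
    def __init__(self, func, session=None):
        self.function = func
        self._session = session

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_session"] = None  # never serialize the live session
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Stand-in for sp_context.get_active_session() on the remote side.
        self._session = "active-session-from-context"

restored = pickle.loads(pickle.dumps(Payload(print, session="local-session")))
assert restored._session == "active-session-from-context"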
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -12,10 +12,17 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import (
+    constants,
+    function_payload_utils,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
+cp.register_pickle_by_value(function_payload_utils)
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -217,20 +224,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()
 
 
-def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
+def resolve_source(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
+) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source, Path):
-        # Validate source
-        source = source
+    elif isinstance(source, (Path, stage_utils.StagePath)):
        if not source.exists():
            raise FileNotFoundError(f"{source} does not exist")
        return source.absolute()
    else:
-        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+        raise ValueError("Unsupported source type. Source must be a stage, file, directory, or callable.")
 
 
-def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
+def resolve_entrypoint(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
+    entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
         return types.PayloadEntrypoint(
@@ -245,11 +255,11 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             # Infer entrypoint from source
             entrypoint = parent
         else:
-            raise ValueError("
+            raise ValueError("Entrypoint must be provided when source is a directory")
     elif entrypoint.is_absolute():
         # Absolute path - validate it's a subpath of source dir
         if not entrypoint.is_relative_to(parent):
-            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
     else:
         # Relative path
         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
@@ -265,6 +275,7 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
             f" the source directory (source={parent}, entrypoint={entrypoint})"
         )
+
     if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
         raise ValueError(
             "Unsupported entrypoint type:"
@@ -285,8 +296,9 @@ class JobPayload:
         *,
         pip_requirements: Optional[list[str]] = None,
     ) -> None:
-
-        self.
+        # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/domain...
+        self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
+        self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -310,7 +322,7 @@ class JobPayload:
         ).collect()
 
         # Upload payload to stage
-        if not isinstance(source, Path):
+        if not isinstance(source, (Path, stage_utils.StagePath)):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -321,27 +333,38 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
             if not any(r.startswith("cloudpickle") for r in pip_requirements):
                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-        elif source.is_dir():
-            # Manually traverse the directory and upload each file, since Snowflake PUT
-            # can't handle directories. Reduce the number of PUT operations by using
-            # wildcard patterns to batch upload files with the same extension.
-            for path in {
-                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p for p in source.resolve().rglob("*") if p.is_file()
-            }:
+
+        elif isinstance(source, stage_utils.StagePath):
+            # copy payload to stage
+            if source == entrypoint.file_path:
+                source = source.parent
+            source_path = source.as_posix() + "/"
+            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+
+        elif isinstance(source, Path):
+            if source.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
                 session.file.put(
-                    str(path),
-                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                    str(source.resolve()),
+                    stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-        else:
-            session.file.put(
-                str(source.resolve()),
-                stage_path.as_posix(),
-                overwrite=True,
-                auto_compress=False,
-            )
-        source = source.parent
+            source = source.parent
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
@@ -502,9 +525,15 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
     return param_code
 
 
-def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
     """Generate an entrypoint script from a Python function."""
-    signature = inspect.signature(func)
+
+    if isinstance(payload, function_payload_utils.FunctionPayload):
+        function = payload.function
+    else:
+        function = payload
+
+    signature = inspect.signature(function)
     if any(
         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
         for p in signature.parameters.values()
@@ -513,21 +542,20 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
 
     # Mirrored from Snowpark generate_python_code() function
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
-    source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+    source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
 
     arg_dict_name = "kwargs"
-    if
+    if isinstance(payload, function_payload_utils.FunctionPayload):
         param_code = f"{arg_dict_name} = {{}}"
     else:
         param_code = _generate_param_handler_code(signature, arg_dict_name)
-
     return f"""
 import sys
 import pickle
 
 try:
 {textwrap.indent(source_code_comment, '    ')}
-    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
 except (TypeError, pickle.PickleError):
     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
         raise RuntimeError(
@@ -551,3 +579,23 @@ if __name__ == '__main__':
 
     __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
+
+
+def create_function_payload(
+    func: Callable[..., Any], *args: Any, **kwargs: Any
+) -> function_payload_utils.FunctionPayload:
+    signature = inspect.signature(func)
+    bound = signature.bind(*args, **kwargs)
+    bound.apply_defaults()
+    session_argument = ""
+    session = None
+    for name, val in list(bound.arguments.items()):
+        if isinstance(val, snowpark.Session):
+            if session:
+                raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
+            session = val
+            session_argument = name
+            del bound.arguments[name]
+    payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
+
+    return payload
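A note on why FunctionPayload.__call__ uses bind_partial: once create_function_payload deletes the Session entry from bound.arguments, the remaining positionals can no longer be represented as a flat args tuple, so inspect surfaces them as kwargs, and the session has to be re-injected by name on the remote side. A self-contained sketch of that detection loop (FakeSession stands in for snowpark.Session):

import inspect

class FakeSession:  # stand-in for snowpark.Session
    pass

def train(session, table: str, epochs: int = 1):
    return table, epochs

# Mirror of the detection loop in create_function_payload:
bound = inspect.signature(train).bind(FakeSession(), "MY_TABLE")
bound.apply_defaults()
for name, val in list(bound.arguments.items()):
    if isinstance(val, FakeSession):
        del bound.arguments[name]  # session is stripped; remote side re-injects it

# With the leading positional gone, the rest surface as kwargs:
print(bound.args, bound.kwargs)  # () {'table': 'MY_TABLE', 'epochs': 1}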
snowflake/ml/jobs/_utils/scripts/constants.py
CHANGED
@@ -1,10 +1,26 @@
+from snowflake.ml.jobs._utils import constants as mljob_constants
+
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
 
 
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+
 # Log start and end messages
-LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
-LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
-
+LOG_START_MSG = getattr(
+    mljob_constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    mljob_constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
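The getattr(module, name, default) idiom above keeps the server-side script compatible with older snowml builds where a constant does not exist yet; a generic illustration of the same version-skew guard:

import math

# Resolves to math.tau where available, else falls back to the literal default,
# exactly as LOG_START_MSG and friends fall back above.
TAU = getattr(math, "tau", 6.283185307179586)
assert abs(TAU - 6.283185307179586) < 1e-12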
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED
@@ -13,7 +13,7 @@ from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG
+from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
@@ -72,28 +72,6 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
-def get_active_node_count() -> int:
-    """
-    Count the number of active nodes in the Ray cluster.
-
-    Returns:
-        int: Total count of active nodes
-    """
-    import ray
-
-    if not ray.is_initialized():
-        ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
-    try:
-        nodes = [node for node in ray.nodes() if node.get("Alive")]
-        total_active = len(nodes)
-
-        logger.info(f"Active nodes: {total_active}")
-        return total_active
-    except Exception as e:
-        logger.warning(f"Error getting active node count: {e}")
-        return 0
-
-
 def wait_for_min_instances(min_instances: int) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.
@@ -108,13 +86,16 @@ def wait_for_min_instances(min_instances: int) -> None:
         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
         return
 
+    # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
+    from common_utils import common_util as mlrs_util
+
     start_time = time.time()
     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
 
     while time.time() - start_time < timeout:
-        total_nodes = get_active_node_count()
+        total_nodes = mlrs_util.get_num_ray_nodes()
 
         if total_nodes >= min_instances:
             elapsed = time.time() - start_time
@@ -128,7 +109,8 @@ def wait_for_min_instances(min_instances: int) -> None:
         time.sleep(check_interval)
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
+        f"{mlrs_util.get_num_ray_nodes()} available"
     )
 
 
@@ -199,7 +181,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     """
     try:
         # Wait for minimum required instances if specified
-        min_instances_str = os.environ.get("
+        min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
         if min_instances_str and int(min_instances_str) > 1:
             wait_for_min_instances(int(min_instances_str))
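The wait loop is a plain poll-until-minimum-or-timeout pattern; a self-contained sketch with a stubbed node counter (get_count stands in for mlrs_util.get_num_ray_nodes()):

import time
from typing import Callable

def wait_for_min(min_count: int, get_count: Callable[[], int],
                 timeout: float = 5.0, interval: float = 0.1) -> None:
    start = time.time()
    while time.time() - start < timeout:
        if get_count() >= min_count:
            return
        time.sleep(interval)
    raise TimeoutError(f"Timed out after {timeout}s; only {get_count()} available")

counts = iter([1, 2, 3])
wait_for_min(3, lambda: next(counts))  # returns on the third poll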
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from math import ceil
 from pathlib import PurePath
 from typing import Any, Optional, Union
@@ -30,7 +31,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = constants.DEFAULT_IMAGE_TAG
+    image_tag = _get_runtime_image_tag()
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -346,3 +347,24 @@ def _merge_lists_of_dicts(
         result[key] = d
 
     return list(result.values())
+
+
+def _get_runtime_image_tag() -> str:
+    """
+    Detect runtime image tag from container environment.
+
+    Checks in order:
+    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
+    2. Falls back to hardcoded default
+
+    Returns:
+        str: The runtime image tag to use for job containers
+    """
+    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
+    if env_tag:
+        logging.debug(f"Using runtime image tag from environment: {env_tag}")
+        return env_tag
+
+    # Fall back to default
+    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
+    return constants.DEFAULT_IMAGE_TAG
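The new tag-resolution order, condensed into a standalone check (the "1.5.0" value is a hypothetical override, not a real published tag):

import os

DEFAULT_IMAGE_TAG = "1.4.2"  # mirrors constants.DEFAULT_IMAGE_TAG in 1.8.6

def get_runtime_image_tag() -> str:
    return os.environ.get("MLRS_CONTAINER_IMAGE_TAG") or DEFAULT_IMAGE_TAG

os.environ.pop("MLRS_CONTAINER_IMAGE_TAG", None)
assert get_runtime_image_tag() == "1.4.2"          # fallback to the default
os.environ["MLRS_CONTAINER_IMAGE_TAG"] = "1.5.0"   # hypothetical pinned tag
assert get_runtime_image_tag() == "1.5.0"          # env var wins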