snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -12,10 +12,17 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
-from snowflake.
-from snowflake.
+from snowflake.connector import errors
+from snowflake.ml.jobs._utils import (
+    constants,
+    function_payload_utils,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark._internal import code_generation
 
+cp.register_pickle_by_value(function_payload_utils)
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
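The new module-level `cp.register_pickle_by_value(function_payload_utils)` call makes cloudpickle serialize that module by value rather than by reference, so a pickled `FunctionPayload` can be unpickled on the server even when the server-side snowml build does not ship `function_payload_utils`. A minimal sketch of the mechanism (the `helpers` module is a hypothetical stand-in):

import cloudpickle

import helpers  # hypothetical local module, not installed on the remote side

# By default, functions from importable modules are pickled "by reference"
# (only the module path is stored), so unpickling would try to re-import
# helpers. Registering the module by value embeds its code in the pickle.
cloudpickle.register_pickle_by_value(helpers)
blob = cloudpickle.dumps(helpers.some_function)  # self-contained payload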
@@ -217,20 +224,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()
 
 
-def resolve_source(
+def resolve_source(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
+) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source, Path):
-        # Validate source
-        source = source
+    elif isinstance(source, (Path, stage_utils.StagePath)):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
     else:
-        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+        raise ValueError("Unsupported source type. Source must be a stage, file, directory, or callable.")
 
 
-def resolve_entrypoint(
+def resolve_entrypoint(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
+    entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
         return types.PayloadEntrypoint(
@@ -245,11 +255,11 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             # Infer entrypoint from source
             entrypoint = parent
         else:
-            raise ValueError("
+            raise ValueError("Entrypoint must be provided when source is a directory")
     elif entrypoint.is_absolute():
         # Absolute path - validate it's a subpath of source dir
         if not entrypoint.is_relative_to(parent):
-            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
     else:
         # Relative path
         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
@@ -265,6 +275,7 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
                 "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
                 f" the source directory (source={parent}, entrypoint={entrypoint})"
             )
+
     if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
         raise ValueError(
             "Unsupported entrypoint type:"
@@ -285,8 +296,9 @@ class JobPayload:
         *,
         pip_requirements: Optional[list[str]] = None,
     ) -> None:
-
-        self.
+        # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/domain...
+        self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
+        self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -300,17 +312,18 @@ class JobPayload:
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.
-        except
-            session.
+            session._conn.run_query("describe stage identifier(?)", params=[stage_name], _force_qmark_paramstyle=True)
+        except errors.ProgrammingError:
+            session._conn.run_query(
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-
+                _force_qmark_paramstyle=True,
+            )
 
         # Upload payload to stage
-        if not isinstance(source, Path):
+        if not isinstance(source, (Path, stage_utils.StagePath)):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -321,27 +334,38 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
             if not any(r.startswith("cloudpickle") for r in pip_requirements):
                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-
-
-            #
-
-
-
-            }
+
+        elif isinstance(source, stage_utils.StagePath):
+            # copy payload to stage
+            if source == entrypoint.file_path:
+                source = source.parent
+            source_path = source.as_posix() + "/"
+            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+
+        elif isinstance(source, Path):
+            if source.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
                 session.file.put(
-                    str(
-                    stage_path.
+                    str(source.resolve()),
+                    stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-
-            session.file.put(
-                str(source.resolve()),
-                stage_path.as_posix(),
-                overwrite=True,
-                auto_compress=False,
-            )
-            source = source.parent
+                source = source.parent
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
@@ -502,9 +526,15 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
     return param_code
 
 
-def generate_python_code(
+def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
     """Generate an entrypoint script from a Python function."""
-
+
+    if isinstance(payload, function_payload_utils.FunctionPayload):
+        function = payload.function
+    else:
+        function = payload
+
+    signature = inspect.signature(function)
     if any(
         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
        for p in signature.parameters.values()
@@ -513,21 +543,20 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
 
     # Mirrored from Snowpark generate_python_code() function
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
-    source_code_comment = _generate_source_code_comment(
+    source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
 
     arg_dict_name = "kwargs"
-    if
+    if isinstance(payload, function_payload_utils.FunctionPayload):
         param_code = f"{arg_dict_name} = {{}}"
     else:
         param_code = _generate_param_handler_code(signature, arg_dict_name)
-
     return f"""
 import sys
 import pickle
 
 try:
 {textwrap.indent(source_code_comment, '    ')}
-    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
 except (TypeError, pickle.PickleError):
     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
         raise RuntimeError(
@@ -551,3 +580,23 @@ if __name__ == '__main__':
 
     __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
+
+
+def create_function_payload(
+    func: Callable[..., Any], *args: Any, **kwargs: Any
+) -> function_payload_utils.FunctionPayload:
+    signature = inspect.signature(func)
+    bound = signature.bind(*args, **kwargs)
+    bound.apply_defaults()
+    session_argument = ""
+    session = None
+    for name, val in list(bound.arguments.items()):
+        if isinstance(val, snowpark.Session):
+            if session:
+                raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
+            session = val
+            session_argument = name
+            del bound.arguments[name]
+    payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
+
+    return payload
snowflake/ml/jobs/_utils/query_helper.py
ADDED
@@ -0,0 +1,9 @@
+from snowflake import snowpark
+
+
+def get_attribute_map(session: snowpark.Session, requested_attributes: dict[str, int]) -> dict[str, int]:
+    metadata = session._conn._cursor.description
+    for index in range(len(metadata)):
+        if metadata[index].name in requested_attributes.keys():
+            requested_attributes[metadata[index].name] = index
+    return requested_attributes
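`get_attribute_map` reads the cursor description of the most recently executed query and rewrites the requested name-to-index map with the actual column positions; the seed values act as assumed positions for names the cursor does not report. A sketch of how `spec_utils._get_node_resources` (further down) drives it, with `MY_POOL` as a placeholder:

rows = session._conn.run_query(
    "show compute pools like ?", params=["MY_POOL"], _force_qmark_paramstyle=True
)
# {"instance_family": 4} seeds index 4; get_attribute_map overwrites it with
# the position reported by the cursor metadata for that column name.
attrs = get_attribute_map(session, {"instance_family": 4})
instance_family = rows["data"][0][attrs["instance_family"]]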
snowflake/ml/jobs/_utils/scripts/constants.py
CHANGED
@@ -1,10 +1,26 @@
+from snowflake.ml.jobs._utils import constants as mljob_constants
+
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
 
 
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+
 # Log start and end messages
-
-
-
+LOG_START_MSG = getattr(
+    mljob_constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    mljob_constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
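The `getattr(module, name, default)` pattern guards against version skew: this script ships with the job payload, while `snowflake.ml.jobs._utils.constants` on the server may be an older build that predates the newer attributes. A toy demonstration of the fallback, with both classes as stand-ins:

class OldConstants:  # stand-in for an outdated server-side constants module
    pass

class NewConstants:  # stand-in for a current constants module
    MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"

assert getattr(OldConstants, "MIN_INSTANCES_ENV_VAR", "fallback") == "fallback"
assert getattr(NewConstants, "MIN_INSTANCES_ENV_VAR", "fallback") == "MLRS_MIN_INSTANCES"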
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED
@@ -13,7 +13,7 @@ from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG
+from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
@@ -72,28 +72,6 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
-def get_active_node_count() -> int:
-    """
-    Count the number of active nodes in the Ray cluster.
-
-    Returns:
-        int: Total count of active nodes
-    """
-    import ray
-
-    if not ray.is_initialized():
-        ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
-    try:
-        nodes = [node for node in ray.nodes() if node.get("Alive")]
-        total_active = len(nodes)
-
-        logger.info(f"Active nodes: {total_active}")
-        return total_active
-    except Exception as e:
-        logger.warning(f"Error getting active node count: {e}")
-        return 0
-
-
 def wait_for_min_instances(min_instances: int) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.
@@ -108,13 +86,16 @@ def wait_for_min_instances(min_instances: int) -> None:
         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
         return
 
+    # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
+    from common_utils import common_util as mlrs_util
+
     start_time = time.time()
     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
 
     while time.time() - start_time < timeout:
-        total_nodes =
+        total_nodes = mlrs_util.get_num_ray_nodes()
 
         if total_nodes >= min_instances:
             elapsed = time.time() - start_time
@@ -128,7 +109,8 @@ def wait_for_min_instances(min_instances: int) -> None:
         time.sleep(check_interval)
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
+        f"{mlrs_util.get_num_ray_nodes()} available"
     )
 
 
@@ -199,7 +181,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     """
     try:
         # Wait for minimum required instances if specified
-        min_instances_str = os.environ.get("
+        min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
         if min_instances_str and int(min_instances_str) > 1:
             wait_for_min_instances(int(min_instances_str))
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -1,20 +1,23 @@
 import logging
+import os
 from math import ceil
 from pathlib import PurePath
 from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, query_helper, types
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.
-    if not rows:
+    rows = session._conn.run_query("show compute pools like ?", params=[compute_pool], _force_qmark_paramstyle=True)
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-
+    requested_attributes = query_helper.get_attribute_map(session, {"instance_family": 4})
+    compute_pool_info = rows["data"]
+    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
@@ -30,7 +33,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag =
+    image_tag = _get_runtime_image_tag()
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -346,3 +349,24 @@ def _merge_lists_of_dicts(
             result[key] = d
 
     return list(result.values())
+
+
+def _get_runtime_image_tag() -> str:
+    """
+    Detect runtime image tag from container environment.
+
+    Checks in order:
+    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
+    2. Falls back to hardcoded default
+
+    Returns:
+        str: The runtime image tag to use for job containers
+    """
+    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
+    if env_tag:
+        logging.debug(f"Using runtime image tag from environment: {env_tag}")
+        return env_tag
+
+    # Fall back to default
+    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
+    return constants.DEFAULT_IMAGE_TAG
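Per the docstring, `constants.RUNTIME_IMAGE_TAG_ENV_VAR` names the `MLRS_CONTAINER_IMAGE_TAG` environment variable, so a job submitted from inside a runtime container can pin the image tag without code changes. A sketch of the two branches (the tag value is hypothetical):

import os

os.environ["MLRS_CONTAINER_IMAGE_TAG"] = "my-runtime-tag"  # hypothetical tag
assert _get_runtime_image_tag() == "my-runtime-tag"

del os.environ["MLRS_CONTAINER_IMAGE_TAG"]
assert _get_runtime_image_tag() == constants.DEFAULT_IMAGE_TAG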
snowflake/ml/jobs/_utils/stage_utils.py
ADDED
@@ -0,0 +1,119 @@
+import os
+import re
+from os import PathLike
+from pathlib import Path, PurePath
+from typing import Union
+
+from snowflake.ml._internal.utils import identifier
+
+PROTOCOL_NAME = "snow"
+_SNOWURL_PATH_RE = re.compile(
+    rf"^(?:(?:{PROTOCOL_NAME}://)?"
+    r"(?<!@)(?P<domain>\w+)/"
+    rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+    r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+)
+
+_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+class StagePath:
+    def __init__(self, path: str) -> None:
+        stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+        if not stage_match:
+            raise ValueError(f"{path} is not a valid stage path")
+        path = path.strip()
+        self._raw_path = path
+        relpath = stage_match.group("relpath")
+        start, _ = stage_match.span("relpath")
+        self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+        self._path = Path(relpath or "")
+
+    @property
+    def parent(self) -> "StagePath":
+        if self._path.parent == Path(""):
+            return StagePath(self._root)
+        else:
+            return StagePath(f"{self._root}/{self._path.parent}")
+
+    @property
+    def root(self) -> str:
+        return self._root
+
+    @property
+    def suffix(self) -> str:
+        return self._path.suffix
+
+    def _compose_path(self, path: Path) -> str:
+        # in pathlib, Path("") = "."
+        if path == Path(""):
+            return self.root
+        else:
+            return f"{self.root}/{path}"
+
+    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if stage_path.root == self.root:
+            return self._path.is_relative_to(stage_path._path)
+        else:
+            return False
+
+    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            return self._path.relative_to(stage_path._path)
+        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+    def absolute(self) -> "StagePath":
+        return self
+
+    def as_posix(self) -> str:
+        return self._compose_path(self._path)
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def exists(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_file(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_dir(self) -> bool:
+        return True
+
+    def is_absolute(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return self.as_posix()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StagePath):
+            raise NotImplementedError
+        return bool(self.root == other.root and self._path == other._path)
+
+    def __fspath__(self) -> str:
+        return self._compose_path(self._path)
+
+    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        path = self
+        for arg in args:
+            path = path._make_child(arg)
+        return path
+
+    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            child_path = self._path.joinpath(stage_path._path)
+            return StagePath(self._compose_path(child_path))
+        else:
+            return stage_path
+
+
+def identify_stage_path(path: str) -> Union[StagePath, Path]:
+    try:
+        stage_path = StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
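`StagePath` mirrors just enough of the `pathlib.Path` surface (`exists`, `is_file`, `absolute`, `joinpath`, `relative_to`, `suffix`, ...) for `payload_utils` to treat stage and local sources uniformly, with the filesystem probes stubbed to `True` until SNOW-2112795 lands. Some illustrative values under the two accepted syntaxes (stage and object names are placeholders):

p = identify_stage_path("@my_stage/jobs/payload.py")
p.root                   # "@my_stage"
p.parent.as_posix()      # "@my_stage/jobs"
p.suffix                 # ".py"

v = identify_stage_path("snow://dataset/my_db.my_schema.my_ds/versions/v1/data.parquet")
v.root                   # "snow://dataset/my_db.my_schema.my_ds/versions/v1"
v.relative_to(v.parent)  # PurePath("data.parquet")

identify_stage_path("./local/payload.py")  # matches neither regex -> pathlib.Path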
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -2,18 +2,22 @@ from dataclasses import dataclass
 from pathlib import PurePath
 from typing import Literal, Optional, Union
 
+from snowflake.ml.jobs._utils import stage_utils
+
 JOB_STATUS = Literal[
     "PENDING",
     "RUNNING",
     "FAILED",
     "DONE",
+    "CANCELLING",
+    "CANCELLED",
     "INTERNAL_ERROR",
 ]
 
 
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: PurePath
+    file_path: Union[PurePath, stage_utils.StagePath]
     main_func: Optional[str]
 
 
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,13 +1,13 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import
+from snowflake.ml.jobs._utils import payload_utils
 
 _PROJECT = "MLJob"
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-            If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,23 +59,17 @@
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*
-            payload =
-            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-
-
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
-                session=session,
+                session=payload.session or session,
+                **kwargs,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job