snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py

```diff
@@ -2,6 +2,8 @@ import functools
 import inspect
 import io
 import itertools
+import keyword
+import logging
 import pickle
 import sys
 import textwrap
@@ -12,17 +14,21 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
-from snowflake.connector import errors
 from snowflake.ml.jobs._utils import (
     constants,
     function_payload_utils,
+    query_helper,
     stage_utils,
     types,
 )
+from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
+logger = logging.getLogger(__name__)
+
 cp.register_pickle_by_value(function_payload_utils)
 
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -31,6 +37,9 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     f"""
     #!/bin/bash
 
+    ##### Get system scripts directory #####
+    SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
+
     ##### Perform common set up steps #####
     set -e # exit if a command fails
 
@@ -74,12 +83,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Check if the local get_instance_ip.py script exists
     HELPER_EXISTS=$(
-        [ -f "get_instance_ip.py" ] && echo "true" || echo "false"
+        [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
     )
 
+
     # Configure IP address and logging directory
     if [ "$HELPER_EXISTS" = "true" ]; then
-        eth0Ip=$(python3 get_instance_ip.py
+        eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
+            "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
     else
         eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
     fi
@@ -102,7 +113,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Determine if it should be a worker or a head node for batch jobs
     if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-        head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
+        head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
         if [ $? -eq 0 ]; then
             # Parse the output using read
             read head_index head_ip head_status<<< "$head_info"
@@ -184,7 +195,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
         # Start the worker shutdown listener in the background
         echo "Starting worker shutdown listener..."
-        python worker_shutdown_listener.py
+        python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
         WORKER_EXIT_CODE=$?
 
         echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
@@ -217,19 +228,59 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
         # After the user's job completes, signal workers to shut down
         echo "User job completed. Signaling workers to shut down..."
-        python signal_workers.py --wait-time 15
+        python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
         echo "Head node job completed. Exiting."
     fi
     """
 ).strip()
 
 
+def resolve_path(path: str) -> types.PayloadPath:
+    try:
+        stage_path = stage_utils.StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
+
+
+def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
+    for source_path, remote_relative_path in payload_specs:
+        payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
+        if isinstance(source_path, stage_utils.StagePath):
+            # only copy files into one stage directory from another stage directory, not from stage file
+            # due to incomplete of StagePath functionality
+            session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+        elif isinstance(source_path, Path):
+            if source_path.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source_path.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
+                session.file.put(
+                    str(source_path.resolve()),
+                    payload_stage_path.as_posix(),
+                    overwrite=True,
+                    auto_compress=False,
+                )
+
+
 def resolve_source(
-    source: Union[
-) -> Union[
+    source: Union[types.PayloadPath, Callable[..., Any]]
+) -> Union[types.PayloadPath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source,
+    elif isinstance(source, types.PayloadPath):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
@@ -238,8 +289,8 @@ def resolve_source(
 
 
 def resolve_entrypoint(
-    source: Union[
-    entrypoint: Optional[
+    source: Union[types.PayloadPath, Callable[..., Any]],
+    entrypoint: Optional[types.PayloadPath],
 ) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
@@ -288,6 +339,73 @@ def resolve_entrypoint(
     )
 
 
+def resolve_additional_payloads(
+    additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
+) -> list[types.PayloadSpec]:
+    """
+    Determine how to stage local packages so that imports continue to work.
+
+    Args:
+        additional_payloads: A list of directory paths, each optionally paired with a dot-separated
+            import path
+            e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
+            if there is no import path, the last part of path will be considered as import path
+            e.g. the import path of "proj/src/helper" is "helper"
+
+    Returns:
+        A list of payloadSpec for additional payloads.
+
+    Raises:
+        FileNotFoundError: If any specified package path does not exist.
+        ValueError: If the format of local_packages is invalid.
+
+    """
+    if not additional_payloads:
+        return []
+
+    logger.warning(
+        "When providing a stage path as an additional payload, "
+        "please ensure it points to a directory. "
+        "Files are not currently supported."
+    )
+
+    additional_payloads_paths = []
+    for pkg in additional_payloads:
+        if isinstance(pkg, str):
+            source_path = resolve_path(pkg).absolute()
+            module_path = source_path.name
+        elif isinstance(pkg, tuple):
+            try:
+                source_path_str, module_path = pkg
+            except ValueError:
+                raise ValueError(
+                    f"Invalid format in `additional_payloads`. "
+                    f"Expected a tuple of (source_path, module_path). Got {pkg}"
+                )
+            source_path = resolve_path(source_path_str).absolute()
+        else:
+            raise ValueError("the format of additional payload is not correct")
+
+        if not source_path.exists():
+            raise FileNotFoundError(f"{source_path} does not exist")
+
+        if isinstance(source_path, Path):
+            if source_path.is_file():
+                raise ValueError(f"file is not supported for additional payloads: {source_path}")
+
+        module_parts = module_path.split(".")
+        for part in module_parts:
+            if not part.isidentifier() or keyword.iskeyword(part):
+                raise ValueError(
+                    f"Invalid module import path '{module_path}'. "
+                    f"'{part}' is not a valid Python identifier or is a keyword."
+                )
+
+        dest_path = PurePath(*module_parts)
+        additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
+    return additional_payloads_paths
+
+
 class JobPayload:
     def __init__(
         self,
@@ -295,11 +413,13 @@ class JobPayload:
         entrypoint: Optional[Union[str, Path]] = None,
         *,
         pip_requirements: Optional[list[str]] = None,
+        additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
     ) -> None:
         # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-        self.source =
-        self.entrypoint =
+        self.source = resolve_path(source) if isinstance(source, str) else source
+        self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
+        self.additional_payloads = additional_payloads
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
         # Prepare local variables
@@ -307,27 +427,29 @@ class JobPayload:
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
         pip_requirements = self.pip_requirements or []
+        additional_payload_specs = resolve_additional_payloads(self.additional_payloads)
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-
-        except
-
+            query_helper.run_query(session, "describe stage identifier(?)", params=[stage_name])
+        except sp_exceptions.SnowparkSQLException:
+            query_helper.run_query(
+                session,
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-                _force_qmark_paramstyle=True,
             )
 
-        # Upload payload to stage
-
+        # Upload payload to stage - organize into app/ subdirectory
+        app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
+        if not isinstance(source, types.PayloadPath):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
-                stage_location=
+                stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
@@ -339,68 +461,48 @@ class JobPayload:
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
-
-            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
 
         elif isinstance(source, Path):
-
-
-            # can't handle directories. Reduce the number of PUT operations by using
-            # wildcard patterns to batch upload files with the same extension.
-            for path in {
-                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
-                for p in source.resolve().rglob("*")
-                if p.is_file()
-            }:
-                session.file.put(
-                    str(path),
-                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
-            else:
-                session.file.put(
-                    str(source.resolve()),
-                    stage_path.as_posix(),
-                    overwrite=True,
-                    auto_compress=False,
-                )
+            upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+            if source.is_file():
                 source = source.parent
 
-
+        upload_payloads(session, app_stage_path, *additional_payload_specs)
+
+        # Upload requirements to app/ directory
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
         if pip_requirements:
             # Upload requirements.txt to stage
            session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=
+                stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
 
-        # Upload startup script
+        # Upload startup script to system/ directory within payload
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-            stage_location=
+            stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
             auto_compress=False,
             overwrite=False, # FIXME
         )
 
-        # Upload system scripts
         scripts_dir = Path(__file__).parent.joinpath("scripts")
         for script_file in scripts_dir.glob("*"):
             if script_file.is_file():
                 session.file.put(
                     script_file.as_posix(),
-
+                    system_stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-
         python_entrypoint: list[Union[str, PurePath]] = [
-            PurePath("mljob_launcher.py"),
-            entrypoint.file_path.relative_to(source),
+            PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"),
+            PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"),
         ]
         if entrypoint.main_func:
             python_entrypoint += ["--script_main_func", entrypoint.main_func]
@@ -409,7 +511,7 @@ class JobPayload:
             stage_path=stage_path,
             entrypoint=[
                 "bash",
-                _STARTUP_SCRIPT_PATH,
+                f"{constants.SYSTEM_MOUNT_PATH}/{_STARTUP_SCRIPT_PATH}",
                 *python_entrypoint,
             ],
         )
```
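The most visible API change in `payload_utils.py` is the new `additional_payloads` option on `JobPayload` and the `resolve_additional_payloads` helper. A minimal sketch of how the mapping resolves, based only on the docstring and code above; the directory names are placeholders, and calling the private helper directly is for illustration only:

```python
from snowflake.ml.jobs._utils import payload_utils

# Hypothetical local project layout (placeholder paths; they must exist locally,
# otherwise resolve_additional_payloads raises FileNotFoundError):
#   proj/src/utils/   -> should stay importable as `src.utils`
#   proj/src/helper/  -> no import path given, so it lands as top-level `helper`
specs = payload_utils.resolve_additional_payloads(
    [("proj/src/utils", "src.utils"), "proj/src/helper"]
)
for spec in specs:
    # Each PayloadSpec pairs a local source directory with the relative
    # destination it is uploaded to under the app/ subdirectory of the job stage.
    print(spec)
```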
snowflake/ml/jobs/_utils/query_helper.py

```diff
@@ -1,9 +1,20 @@
+from typing import Any, Optional, Sequence
+
 from snowflake import snowpark
+from snowflake.snowpark import Row
+from snowflake.snowpark._internal import utils
+from snowflake.snowpark._internal.analyzer import snowflake_plan
 
 
-def
+def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
     metadata = session._conn._cursor.description
-
-
-
-
+    result_set = result["data"]
+    return utils.result_set_to_rows(result_set, metadata)
+
+
+@snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
+def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
+    result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True)
+    if not isinstance(result, dict) or "data" not in result:
+        raise ValueError(f"Unprocessable result: {result}")
+    return result_set_to_rows(session, result)
```
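Other modules in this release call the new helper roughly as sketched below; the `describe stage` query mirrors its use in `payload_utils.py`, while the session setup and stage name are illustrative assumptions rather than part of the diff:

```python
from snowflake.snowpark import Session
from snowflake.ml.jobs._utils import query_helper

# Illustrative only: obtain a session from the default connection configuration.
session = Session.builder.getOrCreate()

# Placeholder stage name; run_query binds parameters qmark-style and
# returns a list of snowflake.snowpark.Row objects.
rows = query_helper.run_query(
    session, "describe stage identifier(?)", params=["MY_JOB_STAGE"]
)
for row in rows:
    print(row)
```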
snowflake/ml/jobs/_utils/scripts/constants.py

```diff
@@ -1,26 +1,4 @@
-from snowflake.ml.jobs._utils import constants as mljob_constants
-
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
-
-
-# The followings are Inherited from snowflake.ml.jobs._utils.constants
-# We need to copy them here since snowml package on the server side does
-# not have the latest version of the code
-
-# Log start and end messages
-LOG_START_MSG = getattr(
-    mljob_constants,
-    "LOG_START_MSG",
-    "--------------------------------\nML job started\n--------------------------------",
-)
-LOG_END_MSG = getattr(
-    mljob_constants,
-    "LOG_END_MSG",
-    "--------------------------------\nML job finished\n--------------------------------",
-)
-
-# min_instances environment variable name
-MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
```
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

```diff
@@ -3,6 +3,7 @@ import copy
 import importlib.util
 import json
 import logging
+import math
 import os
 import runpy
 import sys
@@ -13,23 +14,48 @@ from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
-from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session
 
+try:
+    from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+except ImportError:
+    from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
+
 # Configure logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
+
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+# Log start and end messages
+LOG_START_MSG = getattr(
+    constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
+TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
+JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "/mnt/job_stage/output/mljob_result.pkl")
+PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
 
-# Constants for the
-
-TIMEOUT = 720  # seconds
+# Constants for the wait_for_instances function
+MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1)  # seconds
+TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720)  # seconds
+CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10)  # seconds
 
 
 try:
@@ -72,45 +98,108 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
-def
+def wait_for_instances(
+    min_instances: int,
+    target_instances: int,
+    *,
+    min_wait_time: float = -1,  # seconds
+    timeout: float = 720,  # seconds
+    check_interval: float = 10,  # seconds
+) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.
 
     Args:
         min_instances: Minimum number of instances required
+        target_instances: Target number of instances to wait for
+        min_wait_time: Minimum time to wait for target_instances to be available.
+            If less than 0, automatically set based on target_instances.
+        timeout: Maximum time to wait for min_instances to be available before raising a TimeoutError.
+        check_interval: Maximum time to wait between checks (uses exponential backoff).
+
+    Examples:
+        Scenario 1 - Ideal case (target met quickly):
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=5, timeout=60)
+            If 4 instances are available after 1 second, the function returns without further waiting (target met).
+
+        Scenario 2 - Min instances met, target not reached:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=10, timeout=60)
+            If only 3 instances are available after 10 seconds, the function returns (min requirement satisfied).
+
+        Scenario 3 - Min instances met early, but min_wait_time not elapsed:
+            wait_for_instances(min_instances=2, target_instances=4, min_wait_time=30, timeout=60)
+            If 2 instances are available after 5 seconds, function continues waiting for target_instances
+            until either 4 instances are found or 30 seconds have elapsed.
+
+        Scenario 4 - Timeout scenario:
+            wait_for_instances(min_instances=3, target_instances=5, min_wait_time=10, timeout=30)
+            If only 2 instances are available after 30 seconds, TimeoutError is raised.
+
+        Scenario 5 - Single instance job (early return):
+            wait_for_instances(min_instances=1, target_instances=1, min_wait_time=5, timeout=60)
+            The function returns without waiting because target_instances <= 1.
 
     Raises:
+        ValueError: If arguments are invalid
         TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
     """
-    if min_instances
-
+    if min_instances > target_instances:
+        raise ValueError(
+            f"Minimum instances ({min_instances}) cannot be greater than target instances ({target_instances})"
+        )
+    if timeout < 0:
+        raise ValueError("Timeout must be greater than 0")
+    if check_interval < 0:
+        raise ValueError("Check interval must be greater than 0")
+
+    if target_instances <= 1:
+        logger.debug("Target instances is 1 or less, no need to wait for additional instances")
         return
 
+    if min_wait_time < 0:
+        # Automatically set min_wait_time based on the number of target instances
+        # Using min_wait_time = 3 * log2(target_instances) as a starting point:
+        #   target_instances = 1   => min_wait_time = 0
+        #   target_instances = 2   => min_wait_time = 3
+        #   target_instances = 4   => min_wait_time = 6
+        #   target_instances = 8   => min_wait_time = 9
+        #   target_instances = 32  => min_wait_time = 15
+        #   target_instances = 50  => min_wait_time = 16.9
+        #   target_instances = 100 => min_wait_time = 19.9
+        min_wait_time = min(3 * math.log2(target_instances), timeout / 10)  # Clamp to timeout / 10
+
     # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
     from common_utils import common_util as mlrs_util
 
     start_time = time.time()
-
-
-
+    current_interval = max(min(1, check_interval), 0.1)  # Default 1s, minimum 0.1s
+    logger.debug(
+        "Waiting for instances to be ready "
+        "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format(
+            min_instances, target_instances, timeout, check_interval
+        )
+    )
 
-    while time.time() - start_time < timeout:
+    while (elapsed := time.time() - start_time) < timeout:
         total_nodes = mlrs_util.get_num_ray_nodes()
-
-
-
+        if total_nodes >= target_instances:
+            # Best case scenario: target_instances are already available
+            logger.info(f"Target instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+        elif total_nodes >= min_instances and elapsed >= min_wait_time:
+            # Second best case scenario: target_instances not met within min_wait_time, but min_instances met
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return
 
         logger.debug(
-            f"Waiting for instances: {total_nodes}
-            f"
+            f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
+            f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
         )
-        time.sleep(
+        time.sleep(current_interval)
+        current_interval = min(current_interval * 2, check_interval)  # Exponential backoff
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
-        f"{mlrs_util.get_num_ray_nodes()} available"
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
     )
 
 
@@ -133,6 +222,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     original_argv = sys.argv
     sys.argv = [script_path, *script_args]
 
+    # Ensure payload directory is in sys.path for module imports
+    # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
+    # but user scripts are in the payload directory and may import from each other
+    payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
+    if payload_dir and payload_dir not in sys.path:
+        sys.path.insert(0, payload_dir)
+
     # Create a Snowpark session before running the script
     # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
     session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841
@@ -179,11 +275,22 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
+    # Ensure the output directory exists before trying to write result files.
+    output_dir = os.path.dirname(JOB_RESULT_PATH)
+    os.makedirs(output_dir, exist_ok=True)
+
     try:
         # Wait for minimum required instances if specified
         min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
-
-
+        target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str
+        if target_instances_str and int(target_instances_str) > 1:
+            wait_for_instances(
+                int(min_instances_str),
+                int(target_instances_str),
+                min_wait_time=MIN_WAIT_TIME,
+                timeout=TIMEOUT,
+                check_interval=CHECK_INTERVAL,
+            )
 
         # Log start marker for user script execution
         print(LOG_START_MSG)  # noqa: T201
```