snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/telemetry.py +42 -13
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/_utils/constants.py +10 -1
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +51 -34
- snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
- snowflake/ml/jobs/_utils/spec_utils.py +8 -6
- snowflake/ml/jobs/decorators.py +13 -3
- snowflake/ml/jobs/job.py +206 -26
- snowflake/ml/jobs/manager.py +78 -34
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +31 -17
- snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +20 -32
- snowflake/ml/model/_model_composer/model_composer.py +44 -19
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +424 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/utils/connection_params.py +8 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,5 +1,10 @@
 from snowflake.cortex._classify_text import ClassifyText, classify_text
-from snowflake.cortex._complete import Complete, CompleteOptions, complete
+from snowflake.cortex._complete import (
+    Complete,
+    CompleteOptions,
+    ConversationMessage,
+    complete,
+)
 from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
 from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
 from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
     "Complete",
     "complete",
     "CompleteOptions",
+    "ConversationMessage",
     "EmbedText768",
     "embed_text_768",
     "EmbedText1024",
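The newly exported ConversationMessage rounds out the multi-turn prompt API. A minimal sketch of how the export might be used, assuming an active Snowpark session and a model available in the account (the model name here is illustrative):

```python
from snowflake.cortex import ConversationMessage, complete

conversation: list[ConversationMessage] = [
    {"role": "system", "content": "Answer in one short sentence."},
    {"role": "user", "content": "What is a Snowflake stage?"},
]
# complete() accepts a plain string prompt or a conversation-style list.
print(complete("mistral-large2", conversation))
```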
snowflake/ml/_internal/platform_capabilities.py
CHANGED
@@ -11,6 +11,9 @@ from snowflake.snowpark import (
     session as snowpark_session,
 )
 
+LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+
 
 class PlatformCapabilities:
     """Class that retrieves platform feature values for the currently running server.
@@ -18,12 +21,12 @@ class PlatformCapabilities:
     Example usage:
     ```
     pc = PlatformCapabilities.get_instance(session)
-    if pc.
-        #
-        print("
+    if pc.is_inlined_deployment_spec_enabled():
+        # Inline deployment spec is enabled.
+        print("Inline deployment spec is enabled.")
     else:
-        #
-        print("
+        # Inline deployment spec is disabled.
+        print("Inline deployment spec is disabled or not supported.")
     ```
     """
 
@@ -50,9 +53,11 @@ class PlatformCapabilities:
 
     # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
    # Python 3.11. So, we are ignoring the type for this method.
+    _dummy_features: dict[str, Any] = {"dummy": "dummy"}
+
     @classmethod  # type: ignore[arg-type]
     @contextmanager
-    def mock_features(cls, features: dict[str, Any]) -> None:  # type: ignore[misc]
+    def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
         logging.debug(f"Setting mock features: {features}")
         cls.set_mock_features(features)
         try:
@@ -61,14 +66,11 @@ class PlatformCapabilities:
         logging.debug(f"Clearing mock features: {features}")
         cls.clear_mock_features()
 
-    def is_nested_function_enabled(self) -> bool:
-        return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
-
     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self._get_bool_feature(
+        return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)
 
     def is_live_commit_enabled(self) -> bool:
-        return self._get_bool_feature(
+        return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
 
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
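With the `_dummy_features` default, test code can enter the mock context without fabricating a feature dict. A sketch under the assumption that feature values mock cleanly as booleans:

```python
from typing import Any

from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

# No-argument form: code under test sees the placeholder {"dummy": "dummy"}.
with PlatformCapabilities.mock_features():
    pass

# Explicit form: force the inline deployment spec feature on for one test.
features: dict[str, Any] = {"ENABLE_INLINE_DEPLOYMENT_SPEC": True}
with PlatformCapabilities.mock_features(features):
    pass  # is_inlined_deployment_spec_enabled() should now report True
```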
snowflake/ml/_internal/telemetry.py
CHANGED
@@ -4,6 +4,7 @@ import enum
 import functools
 import inspect
 import operator
+import os
 import sys
 import time
 import traceback
@@ -13,7 +14,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, c
 from typing_extensions import ParamSpec
 
 from snowflake import connector
-from snowflake.connector import telemetry as connector_telemetry, time_util
+from snowflake.connector import connect, telemetry as connector_telemetry, time_util
 from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env
 from snowflake.ml._internal.exceptions import (
@@ -37,6 +38,37 @@ _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
 
+def _get_login_token() -> Union[str, bytes]:
+    with open("/snowflake/session/token") as f:
+        return f.read()
+
+
+def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
+    conn = None
+    if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
+        try:
+            conn = connect(
+                host=os.getenv("SNOWFLAKE_HOST"),
+                account=os.getenv("SNOWFLAKE_ACCOUNT"),
+                token=_get_login_token(),
+                authenticator="oauth",
+            )
+        except Exception:
+            # Failed to get a new SnowflakeConnection in SPCS. Fall back to using the active session.
+            # This will work in some cases once SPCS enables multiple authentication modes, and users select any auth.
+            pass
+
+    if conn is None:
+        try:
+            active_session = next(iter(session._get_active_sessions()))
+            conn = active_session._conn._conn if active_session.telemetry_enabled else None
+        except snowpark_exceptions.SnowparkSessionException:
+            # Failed to get an active session. No connection available.
+            pass
+
+    return conn
+
+
 @enum.unique
 class TelemetryProject(enum.Enum):
     MLOPS = "MLOps"
@@ -378,10 +410,14 @@ def send_custom_usage(
     data: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> None:
-
-
+    conn = _get_snowflake_connection()
+    if conn is None:
+        raise ValueError(
+            """Snowflake connection is required to send custom telemetry. This means there
+            must be at least one active session, or that telemetry is being sent from within an SPCS service."""
+        )
 
-    client = _SourceTelemetryClient(conn=
+    client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
     common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
     data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
     client._send(msg=data)
@@ -501,7 +537,6 @@ def send_api_usage_telemetry(
             return update_stmt_params_if_snowpark_df(result, statement_params)
 
         # prioritize `conn_attr_name` over the active session
-        telemetry_enabled = True
         if conn_attr_name:
             # raise AttributeError if conn attribute does not exist in `self`
             conn = operator.attrgetter(conn_attr_name)(args[0])
@@ -509,16 +544,10 @@ def send_api_usage_telemetry(
                 raise TypeError(
                     f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
                 )
-        # get an active session
         else:
-            try:
-                active_session = next(iter(session._get_active_sessions()))
-                conn = active_session._conn._conn
-                telemetry_enabled = active_session.telemetry_enabled
-            except snowpark_exceptions.SnowparkSessionException:
-                conn = None
+            conn = _get_snowflake_connection()
 
-        if conn is None
+        if conn is None:
             # Telemetry not enabled, just execute without our additional telemetry logic
             try:
                 return ctx.run(execute_func_with_statement_params)
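The shared `_get_snowflake_connection` helper changes the failure mode of `send_custom_usage`: with neither an SPCS login token nor an active session, it now raises a `ValueError` up front. A hedged sketch (argument names beyond those visible in the hunk are assumptions):

```python
from snowflake.ml._internal import telemetry

try:
    telemetry.send_custom_usage(
        "MLOps",  # project
        telemetry_type="custom",
        data={"event": "demo"},
    )
except ValueError as err:
    # Raised when no SPCS token file and no active Snowpark session exist.
    print(err)
```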
snowflake/ml/_internal/utils/identifier.py
CHANGED
@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
 _SF_SCHEMA_LEVEL_OBJECT = (
     rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
 )
-_SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path
+_SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
 _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
 _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)
 
@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
         res.group("db"),
         res.group("schema"),
         res.group("object"),
-        res.group("path"),
+        res.group("path") or "",
     )
 
 
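The regex change is easiest to see at the call site: a leading `@` is now tolerated and a missing trailing path comes back as an empty string rather than `None`. A small sketch, assuming this internal helper behaves as the pattern above suggests:

```python
from snowflake.ml._internal.utils import identifier

db, schema, obj, path = identifier.parse_snowflake_stage_path("@MY_DB.MY_SCHEMA.MY_STAGE/models/v1")
print(db, schema, obj, path)  # MY_DB MY_SCHEMA MY_STAGE /models/v1

# Without a path component, group("path") is None, so "" is returned instead.
*_, path = identifier.parse_snowflake_stage_path("MY_DB.MY_SCHEMA.MY_STAGE")
assert path == ""
```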
snowflake/ml/data/data_connector.py
CHANGED
@@ -249,7 +249,7 @@ class DataConnector:
 
         # Switch to use Runtime's Data Ingester if running in ML runtime
         # Fail silently if the data ingester is not found
-        if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
+        if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR, "").lower() in ("true", "1"):
             try:
                 from runtime_external_entities import get_ingester_class
 
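The stricter check matters because `os.getenv` returns a raw string, and any non-empty string, including `"false"`, is truthy. A standalone illustration:

```python
import os

os.environ["USE_OPTIMIZED_DATA_INGESTOR"] = "false"

# Old-style check: enabled by any non-empty value, even "false".
assert bool(os.getenv("USE_OPTIMIZED_DATA_INGESTOR"))

# New-style check: only an explicit "true" or "1" enables the ingester.
assert os.getenv("USE_OPTIMIZED_DATA_INGESTOR", "").lower() not in ("true", "1")
```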
snowflake/ml/jobs/_utils/constants.py
CHANGED
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
+MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
 STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -13,7 +14,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.2.3"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -37,6 +38,7 @@ RAY_PORTS = {
 # Node health check configuration
 # TODO(SNOW-1937020): Revisit the health check configuration
 ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
 ENABLE_HEALTH_CHECKS = "false"
 
 # Job status polling constants
@@ -47,6 +49,13 @@ JOB_POLL_MAX_DELAY_SECONDS = 1
 IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
 RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
 
+# Log start and end messages
+LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
+
+# Default setting for verbose logging in get_log function
+DEFAULT_VERBOSE_LOG = False
+
 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
 # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
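The start/end markers give client code a stable way to separate user-script output from platform logs. A hypothetical trimming helper (the actual consumer is the verbose-log handling in snowflake/ml/jobs/job.py, which this excerpt does not show):

```python
from snowflake.ml.jobs._utils.constants import LOG_END_MSG, LOG_START_MSG

def trim_job_logs(raw: str) -> str:
    """Keep only the text between the job start/end markers, if present."""
    start = raw.find(LOG_START_MSG)
    end = raw.find(LOG_END_MSG)
    if start == -1 or end == -1:
        return raw  # Markers missing (e.g. older runtime image): keep everything.
    return raw[start + len(LOG_START_MSG) : end].strip()
```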
snowflake/ml/jobs/_utils/interop_utils.py
CHANGED
@@ -80,7 +80,7 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
         # TODO: Check if file exists
         with session.file.get_stream(result_path) as result_stream:
             return ExecutionResult.from_dict(pickle.load(result_stream))
-    except (sp_exceptions.SnowparkSQLException, TypeError,
+    except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
         # Fall back to JSON result if loading pickled result fails for any reason
         result_json_path = os.path.splitext(result_path)[0] + ".json"
         with session.file.get_stream(result_json_path) as result_stream:
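Why the broader except clause helps: a pickled result produced in a different environment can fail to load with `ImportError` (a module referenced by the payload is missing) or `pickle.UnpicklingError` (incompatible stream), not only `TypeError`. A self-contained sketch that forces the `ImportError` case:

```python
import pickle

payload = pickle.dumps(object())
# A same-length module rename keeps the stream well-formed but makes the
# referenced module unimportable, much like a stale result from another env.
broken = payload.replace(b"builtins", b"builtinz")
try:
    pickle.loads(broken)
except (pickle.UnpicklingError, TypeError, ImportError) as err:
    print(f"Falling back to JSON result: {err!r}")
```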
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
 import cloudpickle as cp
+from packaging import version
 
 from snowflake import snowpark
 from snowflake.ml.jobs._utils import constants, types
@@ -97,11 +98,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
         if [ $? -eq 0 ]; then
             # Parse the output using read
-            read head_index head_ip <<< "$head_info"
+            read head_index head_ip head_status<<< "$head_info"
+
+            if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+                NODE_TYPE="worker"
+                echo "{constants.LOG_START_MSG}"
+            fi
 
             # Use the parsed variables
             echo "Head Instance Index: $head_index"
             echo "Head Instance IP: $head_ip"
+            echo "Head Instance Status: $head_status"
+
+            # If the head status is not "READY" or "PENDING", exit early
+            if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
+                echo "Head instance status is not READY or PENDING. Exiting."
+                exit 0
+            fi
 
         else
             echo "Error: Failed to get head instance information."
@@ -109,9 +122,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
             exit 1
         fi
 
-
-        NODE_TYPE="worker"
-    fi
+
     fi
 
     # Common parameters for both head and worker nodes
@@ -160,6 +171,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         # Start Ray on a worker node - run in background
         ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
 
+        echo "Worker node started on address $eth0Ip. See more logs in the head node."
+
+        echo "{constants.LOG_END_MSG}"
+
         # Start the worker shutdown listener in the background
         echo "Starting worker shutdown listener..."
         python worker_shutdown_listener.py
@@ -181,15 +196,16 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
         # Start Ray on the head node
         ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+
         ##### End Ray configuration #####
 
         # TODO: Monitor MLRS and handle process crashes
         python -m web.ml_runtime_grpc_server &
 
         # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+        echo Running command: python "$@"
 
         # Run user's Python entrypoint
-        echo Running command: python "$@"
         python "$@"
 
         # After the user's job completes, signal workers to shut down
@@ -278,17 +294,19 @@ class JobPayload:
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
+        pip_requirements = self.pip_requirements or []
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.sql(
+            session.sql("describe stage identifier(?)", params=[stage_name]).collect()
         except sp_exceptions.SnowparkSQLException:
             session.sql(
-
+                "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
-                " comment = 'Created by snowflake.ml.jobs Python API'"
+                " comment = 'Created by snowflake.ml.jobs Python API'",
+                params=[stage_name],
             ).collect()
 
         # Upload payload to stage
@@ -301,6 +319,8 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
+            if not any(r.startswith("cloudpickle") for r in pip_requirements):
+                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
         elif source.is_dir():
             # Manually traverse the directory and upload each file, since Snowflake PUT
             # can't handle directories. Reduce the number of PUT operations by using
@@ -325,10 +345,10 @@ class JobPayload:
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
-        if self.pip_requirements:
+        if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
-                io.BytesIO("\n".join(self.pip_requirements).encode()),
+                io.BytesIO("\n".join(pip_requirements).encode()),
                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
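The requirement appended for function payloads pins cloudpickle to the serializer's major version with a compatible-release specifier. A sketch of what gets generated, assuming cloudpickle 3.x is installed locally:

```python
import cloudpickle as cp
from packaging import version

requirement = f"cloudpickle~={version.parse(cp.__version__).major}.0"
print(requirement)  # e.g. "cloudpickle~=3.0", meaning >=3.0 and <4.0
```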
@@ -495,13 +515,6 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
 
-    func_code = f"""
-{source_code_comment}
-
-import pickle
-{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
-"""
-
     arg_dict_name = "kwargs"
     if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
         param_code = f"{arg_dict_name} = {{}}"
@@ -509,25 +522,29 @@ import pickle
         param_code = _generate_param_handler_code(signature, arg_dict_name)
 
     return f"""
-### Version guard to check compatibility across Python versions ###
-import os
 import sys
-import warnings
-
-if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
-    warnings.warn(
-        "Python version mismatch: job was created using"
-        " python{sys.version_info.major}.{sys.version_info.minor}"
-        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
-        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
-        " This will be fixed in a future release; for now, please use Python version"
-        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
-        RuntimeWarning,
-        stacklevel=0,
-    )
-### End version guard ###
+import pickle
 
-
+try:
+{textwrap.indent(source_code_comment, '    ')}
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+except (TypeError, pickle.PickleError):
+    if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
+        raise RuntimeError(
+            "Failed to deserialize function due to Python version mismatch."
+            f" Runtime environment is Python {{sys.version_info.major}}.{{sys.version_info.minor}}"
+            " but function was serialized using Python {sys.version_info.major}.{sys.version_info.minor}."
+        ) from None
+    raise
+except AttributeError as e:
+    if 'cloudpickle' in str(e):
+        import cloudpickle as cp
+        raise RuntimeError(
+            "Failed to deserialize function due to cloudpickle version mismatch."
+            f" Runtime environment uses cloudpickle=={{cp.__version__}}"
+            " but job was serialized using cloudpickle=={cp.__version__}."
+        ) from e
+    raise
 
 if __name__ == '__main__':
 {textwrap.indent(param_code, '    ')}
snowflake/ml/jobs/_utils/scripts/constants.py
CHANGED
@@ -2,3 +2,9 @@
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
+
+
+# Log start and end messages
+# Inherited from snowflake.ml.jobs._utils.constants
+LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
snowflake/ml/jobs/_utils/scripts/get_instance_ip.py
CHANGED
@@ -29,7 +29,7 @@ def get_self_ip() -> Optional[str]:
         return None
 
 
-def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
+def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     """Get the first instance of a batch job based on start time and instance ID.
 
     Args:
@@ -42,7 +42,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
 
     session = session_utils.get_session()
     df = session.sql(f"show service instances in service {service_name}")
-    result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()
+    result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
 
     if not result:
         return None
@@ -57,7 +57,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
     ip_address = head_instance["ip_address"]
     try:
         socket.inet_aton(ip_address)  # Validate IPv4 address
-        return (head_instance["instance_id"], ip_address)
+        return (head_instance["instance_id"], ip_address, head_instance["status"])
     except OSError:
         logger.error(f"Error: Invalid IP address format: {ip_address}")
         return None
@@ -110,7 +110,7 @@ def main():
         head_info = get_first_instance(args.service_name)
         if head_info:
             # Print to stdout to allow capture but don't use logger
-            sys.stdout.write(
+            sys.stdout.write(" ".join(head_info) + "\n")
             sys.exit(0)
         time.sleep(args.retry_interval)
     # If we get here, we've timed out
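`get_first_instance` now returns a three-tuple, which `main()` writes to stdout as one space-separated line and the startup script consumes with `read head_index head_ip head_status`. A round-trip sketch of that contract:

```python
head_info = ("0", "10.0.0.11", "READY")  # (instance_id, ip_address, status)
line = " ".join(head_info) + "\n"        # what main() writes to stdout

head_index, head_ip, head_status = line.split()
assert head_status in ("READY", "PENDING")  # otherwise the startup script exits early
```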
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED
@@ -2,25 +2,35 @@ import argparse
 import copy
 import importlib.util
 import json
+import logging
 import os
 import runpy
 import sys
+import time
 import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
+from constants import LOG_END_MSG, LOG_START_MSG
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-
 JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
 
+# Constants for the wait_for_min_instances function
+CHECK_INTERVAL = 10  # seconds
+TIMEOUT = 720  # seconds
+
 
 try:
     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -59,7 +69,67 @@ class SimpleJSONEncoder(json.JSONEncoder):
         try:
             return super().default(obj)
         except TypeError:
-            return
+            return f"Unserializable object: {repr(obj)}"
+
+
+def get_active_node_count() -> int:
+    """
+    Count the number of active nodes in the Ray cluster.
+
+    Returns:
+        int: Total count of active nodes
+    """
+    import ray
+
+    if not ray.is_initialized():
+        ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
+    try:
+        nodes = [node for node in ray.nodes() if node.get("Alive")]
+        total_active = len(nodes)
+
+        logger.info(f"Active nodes: {total_active}")
+        return total_active
+    except Exception as e:
+        logger.warning(f"Error getting active node count: {e}")
+        return 0
+
+
+def wait_for_min_instances(min_instances: int) -> None:
+    """
+    Wait until the specified minimum number of instances are available in the Ray cluster.
+
+    Args:
+        min_instances: Minimum number of instances required
+
+    Raises:
+        TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
+    """
+    if min_instances <= 1:
+        logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+        return
+
+    start_time = time.time()
+    timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
+    check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
+    logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+
+    while time.time() - start_time < timeout:
+        total_nodes = get_active_node_count()
+
+        if total_nodes >= min_instances:
+            elapsed = time.time() - start_time
+            logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+
+        logger.debug(
+            f"Waiting for instances: {total_nodes}/{min_instances} available "
+            f"(elapsed: {time.time() - start_time:.1f}s)"
+        )
+        time.sleep(check_interval)
+
+    raise TimeoutError(
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only {get_active_node_count()} available"
+    )
 
 
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
@@ -86,6 +156,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841
 
     try:
+
         if main_func:
             # Use importlib for scripts with a main function defined
             module_name = Path(script_path).stem
@@ -126,9 +197,21 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
-    # Run the script with the specified arguments
     try:
+        # Wait for minimum required instances if specified
+        min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
+        if min_instances_str and int(min_instances_str) > 1:
+            wait_for_min_instances(int(min_instances_str))
+
+        # Log start marker for user script execution
+        print(LOG_START_MSG)  # noqa: T201
+
+        # Run the script with the specified arguments
         result = run_script(script_path, *script_args, main_func=script_main_func)
+
+        # Log end marker for user script execution
+        print(LOG_END_MSG)  # noqa: T201
+
         result_obj = ExecutionResult(result=result)
         return result_obj
     except Exception as e:
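One detail of the gating logic above: the default passed to `os.environ.get` is the int 1 rather than a string, and `int()` accepts both, so the guard behaves the same whether or not `JOB_MIN_INSTANCES` is set. A sketch:

```python
import os

min_instances = int(os.environ.get("JOB_MIN_INSTANCES", 1))
if min_instances > 1:
    print(f"Would block until {min_instances} Ray nodes are alive before running the script.")
```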
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -11,7 +11,7 @@ from snowflake.ml.jobs._utils import constants, types
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.sql(
+    rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
@@ -85,7 +85,8 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
 ) -> dict[str, Any]:
     """
@@ -96,13 +97,13 @@ def generate_service_spec(
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
-
+        target_instances: Number of instances for multi-node job
         enable_metrics: Enable platform metrics for the job
+        min_instances: Minimum number of instances required to start the job
 
     Returns:
         Job service specification
     """
-    is_multi_node = num_instances is not None and num_instances > 1
     image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
@@ -180,10 +181,11 @@ def generate_service_spec(
     }
     endpoints = []
 
-    if is_multi_node:
+    if target_instances > 1:
         # Update environment variables for multi-node job
         env_vars.update(constants.RAY_PORTS)
-        env_vars[
+        env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
+        env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
 
         # Define Ray endpoints for intra-service instance communication
         ray_endpoints = [