snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/telemetry.py +42 -13
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/_utils/constants.py +10 -1
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +51 -34
- snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
- snowflake/ml/jobs/_utils/spec_utils.py +8 -6
- snowflake/ml/jobs/decorators.py +13 -3
- snowflake/ml/jobs/job.py +206 -26
- snowflake/ml/jobs/manager.py +78 -34
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +31 -17
- snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +20 -32
- snowflake/ml/model/_model_composer/model_composer.py +44 -19
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +424 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/utils/connection_params.py +8 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py
CHANGED
@@ -24,8 +24,11 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[dict[str, str]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
@@ -38,8 +41,12 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
+            If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use for the job.
+        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -65,8 +72,11 @@ def remote(
             external_access_integrations=external_access_integrations,
             query_warehouse=query_warehouse,
             env_vars=env_vars,
-
+            target_instances=target_instances,
+            min_instances=min_instances,
             enable_metrics=enable_metrics,
+            database=database,
+            schema=schema,
             session=session,
         )
         assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
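
The net effect of these decorator changes is that multi-node sizing and an explicit database/schema placement can be pinned at decoration time. A minimal sketch of the new surface, assuming `remote` is re-exported from `snowflake.ml.jobs` and its leading `compute_pool`/`stage_name` parameters are unchanged from earlier releases; the pool, stage, database, and schema names below are placeholders, not values from this package:

    from snowflake.ml.jobs import remote

    @remote(
        "MY_COMPUTE_POOL",           # placeholder compute pool name
        stage_name="payload_stage",  # placeholder stage for the job payload
        target_instances=4,          # total nodes requested for the job
        min_instances=2,             # job waits until at least 2 nodes are free
        database="MY_DB",            # must be paired with schema (validated in manager._submit_job)
        schema="MY_SCHEMA",
    )
    def train(table_name: str) -> float:
        ...

    job = train("MY_DB.MY_SCHEMA.TRAINING_DATA")  # returns an MLJob[float] handle, not a float

Note the return type in the signature above: the wrapped callable yields `jb.MLJob[_ReturnValue]`, so the training result comes back through `job.result()` rather than the call itself.
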
snowflake/ml/jobs/job.py
CHANGED
@@ -1,18 +1,25 @@
+import logging
+import os
 import time
+from functools import cached_property
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 
 import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs._utils import constants, interop_utils, types
-from snowflake.snowpark import context as sp_context
+from snowflake.snowpark import Row, context as sp_context
+from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
 T = TypeVar("T")
 
+logger = logging.getLogger(__name__)
+
 
 class MLJob(Generic[T]):
     def __init__(
@@ -28,6 +35,21 @@ class MLJob(Generic[T]):
         self._status: types.JOB_STATUS = "PENDING"
         self._result: Optional[interop_utils.ExecutionResult] = None
 
+    @cached_property
+    def name(self) -> str:
+        return identifier.parse_schema_level_object_identifier(self.id)[-1]
+
+    @cached_property
+    def target_instances(self) -> int:
+        return _get_target_instances(self._session, self.id)
+
+    @cached_property
+    def min_instances(self) -> int:
+        try:
+            return int(self._container_spec["env"].get(constants.MIN_INSTANCES_ENV_VAR, 1))
+        except TypeError:
+            return 1
+
     @property
     def id(self) -> str:
         """Get the unique job ID"""
@@ -41,6 +63,12 @@ class MLJob(Generic[T]):
         self._status = _get_status(self._session, self.id)
         return self._status
 
+    @cached_property
+    def _compute_pool(self) -> str:
+        """Get the job's compute pool name."""
+        row = _get_service_info(self._session, self.id)
+        return cast(str, row["compute_pool"])
+
     @property
     def _service_spec(self) -> dict[str, Any]:
         """Get the job's service spec."""
@@ -67,19 +95,38 @@ class MLJob(Generic[T]):
         """Get the job's result file location."""
         result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path is None:
-            raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
+            raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
         return f"{self._stage_path}/{result_path}"
 
     @overload
-    def get_logs(
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[True],
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> list[str]:
         ...
 
     @overload
-    def get_logs(
+    def get_logs(
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: Literal[False] = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
+    ) -> str:
         ...
 
     def get_logs(
-        self,
+        self,
+        limit: int = -1,
+        instance_id: Optional[int] = None,
+        *,
+        as_list: bool = False,
+        verbose: bool = constants.DEFAULT_VERBOSE_LOG,
     ) -> Union[str, list[str]]:
         """
         Return the job's execution logs.
@@ -89,17 +136,20 @@ class MLJob(Generic[T]):
             instance_id: Optional instance ID to get logs from a specific instance.
                 If not provided, returns logs from the head node.
             as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
+            verbose: Whether to return the full log or just the user log.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit, instance_id)
+        logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
         return logs
 
-    def show_logs(
+    def show_logs(
+        self, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = constants.DEFAULT_VERBOSE_LOG
+    ) -> None:
         """
         Display the job's execution logs.
 
@@ -107,8 +157,9 @@ class MLJob(Generic[T]):
             limit: The maximum number of lines to display. Negative values are treated as no limit.
             instance_id: Optional instance ID to get logs from a specific instance.
                 If not provided, displays logs from the head node.
+            verbose: Whether to return the full log or just the user log.
         """
-        print(self.get_logs(limit, instance_id, as_list=False))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id, as_list=False, verbose=verbose))  # noqa: T201: we need to print here.
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
@@ -126,9 +177,18 @@ class MLJob(Generic[T]):
         """
         delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Start with 100ms delay
         start_time = time.monotonic()
-
+        warning_shown = False
+        while (status := self.status) not in TERMINAL_JOB_STATUSES:
+            if status == "PENDING" and not warning_shown:
+                pool_info = _get_compute_pool_info(self._session, self._compute_pool)
+                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
+                    logger.warning(
+                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use)."
+                        " Job execution may be delayed."
+                    )
+                warning_shown = True
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
+                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
@@ -154,11 +214,11 @@ class MLJob(Generic[T]):
         try:
             self._result = interop_utils.fetch_result(self._session, self._result_path)
         except Exception as e:
-            raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
+            raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
 
         if self._result.success:
             return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
+        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
@@ -172,19 +232,21 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
                 return cast(types.JOB_STATUS, row["status"])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
-
+        row = _get_service_info(session, job_id)
        return cast(types.JOB_STATUS, row["status"])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
-
+    row = _get_service_info(session, job_id)
     return cast(dict[str, Any], yaml.safe_load(row["spec"]))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
-def _get_logs(
+def _get_logs(
+    session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -193,13 +255,20 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
         instance_id: Optional instance ID to get logs from a specific instance.
+        verbose: Whether to return the full log or just the portion between START and END messages.
 
     Returns:
         The job's execution logs.
+
+    Raises:
+        RuntimeError: if failed to get head instance_id
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
-
+        try:
+            instance_id = _get_head_instance_id(session, job_id)
+        except RuntimeError:
+            instance_id = None
 
     # Assemble params: [job_id, instance_id, container_name, (optional) limit]
     params: list[Any] = [
@@ -209,12 +278,50 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
     ]
     if limit > 0:
         params.append(limit)
+    try:
+        (row,) = session.sql(
+            f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
+            params=params,
+        ).collect()
+    except SnowparkSQLException as e:
+        if "Container Status: PENDING" in e.message:
+            logger.warning("Waiting for container to start. Logs will be shown when available.")
+            return ""
+        else:
+            # event table accepts job name, not fully qualified name
+            # cast is to resolve the type check error
+            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+            db = cast(str, db or session.get_current_database())
+            schema = cast(str, schema or session.get_current_schema())
+            logs = _get_service_log_from_event_table(
+                session, db, schema, name, limit, instance_id if instance_id else None
+            )
+            if len(logs) == 0:
+                raise RuntimeError(
+                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+                )
+            return os.linesep.join(row[0] for row in logs)
+
+    full_log = str(row[0])
+
+    # If verbose is True, return the complete log
+    if verbose:
+        return full_log
+
+    # Otherwise, extract only the portion between LOG_START_MSG and LOG_END_MSG
+    start_idx = full_log.find(constants.LOG_START_MSG)
+    if start_idx != -1:
+        start_idx += len(constants.LOG_START_MSG)
+    else:
+        # If start message not found, start from the beginning
+        start_idx = 0
 
-
-
-
-
-
+    end_idx = full_log.find(constants.LOG_END_MSG, start_idx)
+    if end_idx == -1:
+        # If end message not found, return everything after start
+        end_idx = len(full_log)
+
+    return full_log[start_idx:end_idx].strip()
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
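
The non-verbose branch in the hunk above reduces to a substring slice between two sentinel markers. A self-contained sketch of that slicing, with invented sentinel values standing in for `constants.LOG_START_MSG` and `constants.LOG_END_MSG` (whose real values are not shown in this diff):

    # Standalone sketch of the non-verbose log slicing; the real sentinels live
    # in snowflake/ml/jobs/_utils/constants.py and these values are made up.
    LOG_START_MSG = "--- USER LOG START ---"
    LOG_END_MSG = "--- USER LOG END ---"

    def extract_user_log(full_log: str) -> str:
        start = full_log.find(LOG_START_MSG)
        # Skip past the start marker when present; otherwise keep the prefix too.
        start = start + len(LOG_START_MSG) if start != -1 else 0
        end = full_log.find(LOG_END_MSG, start)
        if end == -1:  # No end marker: keep everything after start.
            end = len(full_log)
        return full_log[start:end].strip()

    assert extract_user_log("boot noise\n--- USER LOG START ---\nhello\n--- USER LOG END ---") == "hello"
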
@@ -223,18 +330,31 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
     Retrieve the head instance ID of a job.
 
     Args:
-        session: The Snowpark session to use.
-        job_id: The job ID.
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
 
     Returns:
-        The head instance ID of the job
+        Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
+
+    Raises:
+        RuntimeError: If the instances died or if some instances disappeared.
+
     """
-
+    try:
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    except SnowparkSQLException:
+        # service may be deleted
+        raise RuntimeError("Couldn't retrieve instances")
     if not rows:
         return None
+    if _get_target_instances(session, job_id) > len(rows):
+        raise RuntimeError("Couldn't retrieve head instance due to missing instances.")
 
     # Sort by start_time first, then by instance_id
-
+    try:
+        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    except TypeError:
+        raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
     if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
@@ -243,3 +363,63 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return int(head_instance["instance_id"])
     except (ValueError, TypeError):
         return 0
+
+
+def _get_service_log_from_event_table(
+    session: snowpark.Session, database: str, schema: str, name: str, limit: int, instance_id: Optional[int]
+) -> list[Row]:
+    params: list[Any] = [
+        database,
+        schema,
+        name,
+    ]
+    query = [
+        "SELECT VALUE FROM snowflake.telemetry.events_view",
+        'WHERE RESOURCE_ATTRIBUTES:"snow.database.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?',
+        'AND RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+    ]
+
+    if instance_id:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
+        params.append(instance_id)
+
+    query.append("AND RECORD_TYPE = 'LOG'")
+    # sort by TIMESTAMP; although OBSERVED_TIMESTAMP is for log, it is NONE currently when record_type is log
+    query.append("ORDER BY TIMESTAMP")
+
+    if limit > 0:
+        query.append("LIMIT ?")
+        params.append(limit)
+
+    rows = session.sql(
+        "\n".join(line for line in query if line),
+        params=params,
+    ).collect()
+    return rows
+
+
+def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    return row
+
+
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
+    """
+    Check if the compute pool has enough available instances.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        compute_pool (str): The name of the compute pool.
+
+    Returns:
+        Row: The compute pool information.
+    """
+    (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
+    return pool_info
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
+    row = _get_service_info(session, job_id)
+    return int(row["target_instances"]) if row["target_instances"] else 0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 import pathlib
 import textwrap
-from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
+from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
 import yaml
@@ -52,7 +52,7 @@ def list_jobs(
         query += f" LIMIT {limit}"
     df = session.sql(query)
     df = df.select(
-        df['"name"']
+        df['"name"'],
         df['"owner"'],
         df['"status"'],
         df['"created_on"'],
@@ -65,16 +65,16 @@
 def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
     """Retrieve a job service from the backend."""
     session = session or get_active_session()
-
     try:
-
-
+        database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
+        database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
+        schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
     except ValueError as e:
         raise ValueError(f"Invalid job ID: {job_id}") from e
 
+    job_id = f"{database}.{schema}.{job_name}"
     try:
         # Validate that job exists by doing a status check
-        # FIXME: Retrieve return path
         job = jb.MLJob[Any](job_id, session=session)
         _ = job.status
         return job
@@ -108,8 +108,11 @@ def submit_file(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> jb.MLJob[None]:
     """
@@ -125,8 +128,11 @@ def submit_file(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use.
+        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -142,8 +148,11 @@ def submit_file(
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
+        database=database,
+        schema=schema,
         session=session,
     )
 
@@ -161,8 +170,11 @@ def submit_directory(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> jb.MLJob[None]:
     """
@@ -179,8 +191,11 @@ def submit_directory(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use.
+        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -197,8 +212,11 @@ def submit_directory(
         external_access_integrations=external_access_integrations,
         query_warehouse=query_warehouse,
         spec_overrides=spec_overrides,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
+        database=database,
+        schema=schema,
         session=session,
     )
 
@@ -216,8 +234,11 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> jb.MLJob[None]:
     ...
@@ -236,8 +257,11 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> jb.MLJob[T]:
     ...
@@ -251,7 +275,7 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
-        "
+        "target_instances",
         "enable_metrics",
     ],
 )
@@ -267,8 +291,11 @@ def _submit_job(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     spec_overrides: Optional[dict[str, Any]] = None,
-
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> jb.MLJob[T]:
     """
@@ -285,8 +312,11 @@ def _submit_job(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         spec_overrides: Custom service specification overrides to apply.
-
+        target_instances: The number of instances to use for the job. If none specified, single node job is created.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use.
+        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -294,17 +324,27 @@ def _submit_job(
 
     Raises:
         RuntimeError: If required Snowflake features are not enabled.
+        ValueError: If database or schema value(s) are invalid
     """
-
-
-
-
-
+    if database and not schema:
+        raise ValueError("Schema must be specified if database is specified.")
+    if target_instances < 1 or min_instances < 1:
+        raise ValueError("target_instances and min_instances must be greater than 0.")
+    if min_instances > target_instances:
+        raise ValueError("min_instances must be less than or equal to target_instances.")
 
     session = session or get_active_session()
-
-
-
+
+    # Validate database and schema identifiers on client side since
+    # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
+    database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
+    schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
+
+    job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+    job_id = f"{database}.{schema}.{job_name}"
+    stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
+    stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
+    stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
 
     # Upload payload
     uploaded_payload = payload_utils.JobPayload(
@@ -319,7 +359,8 @@ def _submit_job(
         compute_pool=compute_pool,
         payload=uploaded_payload,
         args=args,
-
+        target_instances=target_instances,
+        min_instances=min_instances,
         enable_metrics=enable_metrics,
     )
     spec_overrides = spec_utils.generate_spec_overrides(
@@ -331,31 +372,34 @@ def _submit_job(
 
     # Generate SQL command for job submission
    query_template = textwrap.dedent(
-
+        """\
         EXECUTE JOB SERVICE
-        IN COMPUTE POOL
+        IN COMPUTE POOL IDENTIFIER(?)
         FROM SPECIFICATION $$
-        {
+        {}
         $$
-        NAME =
+        NAME = IDENTIFIER(?)
         ASYNC = TRUE
         """
     )
+    params: list[Any] = [compute_pool, job_id]
     query = query_template.format(yaml.dump(spec)).splitlines()
     if external_access_integrations:
         external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
         query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
     query_warehouse = query_warehouse or session.get_current_warehouse()
     if query_warehouse:
-        query.append(
-
-
+        query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
+        params.append(query_warehouse)
+    if target_instances > 1:
+        query.append("REPLICAS = ?")
+        params.append(target_instances)
 
     # Submit job
     query_text = "\n".join(line for line in query if line)
 
     try:
-        _ = session.sql(query_text).collect()
+        _ = session.sql(query_text, params=params).collect()
     except SnowparkSQLException as e:
         if "invalid property 'ASYNC'" in e.message:
             raise RuntimeError(
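
The substantive hardening in this file is the switch from string-interpolated SQL to bind parameters. A reduced sketch of the pattern the new `_submit_job` follows; note that only identifiers and the replica count can be bound, while the YAML spec body must still be inlined inside the $$ block:

    from typing import Any, Optional
    from snowflake import snowpark

    def submit_job_service(
        session: snowpark.Session,
        compute_pool: str,
        job_id: str,
        spec_yaml: str,  # stands in for yaml.dump(spec)
        query_warehouse: Optional[str] = None,
        target_instances: int = 1,
    ) -> None:
        # Identifiers go through IDENTIFIER(?) binds; the spec body cannot be
        # bound, so it is inlined inside the $$ ... $$ block as before.
        query = [
            "EXECUTE JOB SERVICE",
            "IN COMPUTE POOL IDENTIFIER(?)",
            f"FROM SPECIFICATION $$\n{spec_yaml}\n$$",
            "NAME = IDENTIFIER(?)",
            "ASYNC = TRUE",
        ]
        params: list[Any] = [compute_pool, job_id]
        if query_warehouse:
            query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
            params.append(query_warehouse)
        if target_instances > 1:
            query.append("REPLICAS = ?")
            params.append(target_instances)
        session.sql("\n".join(query), params=params).collect()

Binding the pool, job name, and warehouse through IDENTIFIER(?) avoids the quoting pitfalls of f-string SQL; database and schema are still resolved client-side first, since (per the code comment in the hunk above) the EXECUTE JOB SERVICE parser struggles with qualified identifiers.
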