snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/platform_capabilities.py +4 -0
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/experiment/experiment_tracking.py +63 -19
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +50 -11
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +54 -36
- snowflake/ml/model/__init__.py +16 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
- snowflake/ml/model/_client/model/model_version_impl.py +44 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +50 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +48 -21
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +1 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED

@@ -12,12 +12,19 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
-from snowflake.ml.jobs.
+from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
+from snowflake.ml.jobs._utils import (
+    constants,
+    payload_utils,
+    query_helper,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
-TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
+TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}
 
 T = TypeVar("T")
 
@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
         self._session = session or sp_context.get_active_session()
 
         self._status: types.JOB_STATUS = "PENDING"
-        self._result: Optional[
+        self._result: Optional[interop_result.ExecutionResult] = None
+
+    @cached_property
+    def _service_info(self) -> types.ServiceInfo:
+        """Get the job's service info."""
+        return _resolve_service_info(self.id, self._session)
 
     @cached_property
     def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
 
     @cached_property
     def target_instances(self) -> int:
-        return
+        return self._service_info.target_instances
 
     @cached_property
     def min_instances(self) -> int:
@@ -69,8 +81,7 @@
     @cached_property
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
-
-        return cast(str, row["compute_pool"])
+        return self._service_info.compute_pool
 
     @property
     def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@
     @property
     def _container_spec(self) -> dict[str, Any]:
         """Get the job's main container spec."""
-
+        try:
+            containers = self._service_spec["spec"]["containers"]
+        except SnowparkSQLException as e:
+            if e.sql_error_code == 2003:
+                # If the job is deleted, the service spec is not available
+                return {}
+            raise
         if len(containers) == 1:
             return cast(dict[str, Any], containers[0])
         try:
@@ -105,22 +122,28 @@
         if result_path_str is None:
             raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
 
-
-        result_path = Path(result_path_str)
-        if not result_path.is_absolute():
-            return f"{self._stage_path}/{result_path.as_posix()}"
+        return self._transform_path(result_path_str)
 
-
+    def _transform_path(self, path_str: str) -> str:
+        """Transform a local path within the container to a stage path."""
+        path = payload_utils.resolve_path(path_str)
+        if isinstance(path, stage_utils.StagePath):
+            # Stage paths need no transformation
+            return path.as_posix()
+        if not path.is_absolute():
+            # Assume relative paths are relative to stage mount path
+            return f"{self._stage_path}/{path.as_posix()}"
+
+        # If result path is absolute, rebase it onto the stage mount path
+        # TODO: Rather than matching by name, use the longest mount path which matches
         volume_mounts = self._container_spec["volumeMounts"]
         stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
         stage_mount = Path(stage_mount_str)
         try:
-            relative_path =
+            relative_path = path.relative_to(stage_mount)
             return f"{self._stage_path}/{relative_path.as_posix()}"
         except ValueError:
-            raise ValueError(
-                f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
-            )
+            raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
 
     @overload
     def get_logs(
@@ -165,7 +188,14 @@
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(
+        logs = _get_logs(
+            self._session,
+            self.id,
+            limit,
+            instance_id,
+            self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
+            verbose,
+        )
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
@@ -218,7 +248,6 @@
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
-    @snowpark._internal.utils.private_preview(version="1.8.2")
     @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def result(self, timeout: float = -1) -> T:
         """
@@ -237,13 +266,13 @@
         if self._result is None:
             self.wait(timeout)
             try:
-                self._result = interop_utils.
+                self._result = interop_utils.load_result(
+                    self._result_path, session=self._session, path_transform=self._transform_path
+                )
             except Exception as e:
-                raise RuntimeError(f"Failed to retrieve result for job
+                raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e
 
-
-        return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
+        return cast(T, self._result.get_value())
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def cancel(self) -> None:
@@ -256,22 +285,28 @@
             self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
             logger.debug(f"Cancellation requested for job {self.id}")
         except SnowparkSQLException as e:
-            raise RuntimeError(f"Failed to cancel job
+            raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
     """Retrieve job or job instance execution status."""
-
-
-
-
-
-
-
-
-
-
+    try:
+        if instance_id is not None:
+            # Get specific instance status
+            rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
+            for row in rows:
+                if row["instance_id"] == str(instance_id):
+                    return cast(types.JOB_STATUS, row["status"])
+            raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+        else:
+            row = _get_service_info(session, job_id)
+            return cast(types.JOB_STATUS, row["status"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            return cast(types.JOB_STATUS, row["STATUS"])
+        raise
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -542,8 +577,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
-
-
+    try:
+        row = _get_service_info(session, job_id)
+        return int(row["target_instances"])
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, job_id)
+            try:
+                params = json.loads(row["PARAMETERS"])
+                if isinstance(params, dict):
+                    return int(params.get("REPLICAS", 1))
+                else:
+                    return 1
+            except (json.JSONDecodeError, ValueError):
+                return 1
+        raise
 
 
 def _get_logs_spcs(
@@ -581,3 +629,87 @@
     query.append(f" LIMIT {limit};")
     rows = session.sql("\n".join(query)).collect()
     return rows
+
+
+def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
+    """
+    Retrieve the service info from the SPCS interface.
+
+    Args:
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
+
+    Returns:
+        Any: The service info.
+
+    Raises:
+        SnowparkSQLException: If the job does not exist or is too old to retrieve.
+    """
+    db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+    db = db or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    rows = query_helper.run_query(
+        session,
+        """
+        select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
+        from table(snowflake.spcs.get_job_history())
+        where database_name = ? and schema_name = ? and name = ?
+        """,
+        params=(db, schema, name),
+    )
+    if rows:
+        return rows[0]
+    else:
+        raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
+
+
+def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
+    try:
+        row = _get_service_info(session, id)
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 2003:
+            row = _get_service_info_spcs(session, id)
+        else:
+            raise
+    if not row:
+        raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
+
+    if "compute_pool" in row:
+        compute_pool = row["compute_pool"]
+    elif "COMPUTE_POOL_NAME" in row:
+        compute_pool = row["COMPUTE_POOL_NAME"]
+    else:
+        raise ValueError(f"compute_pool not found in row: {row}")
+
+    if "status" in row:
+        status = row["status"]
+    elif "STATUS" in row:
+        status = row["STATUS"]
+    else:
+        raise ValueError(f"status not found in row: {row}")
+    # Normalize target_instances
+    target_instances: int
+    if "target_instances" in row and row["target_instances"] is not None:
+        try:
+            target_instances = int(row["target_instances"])
+        except (ValueError, TypeError):
+            target_instances = 1
+    elif "PARAMETERS" in row and row["PARAMETERS"]:
+        try:
+            params = json.loads(row["PARAMETERS"])
+            target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
+        except (json.JSONDecodeError, ValueError, TypeError):
+            target_instances = 1
+    else:
+        target_instances = 1
+
+    database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
+    schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
+
+    return types.ServiceInfo(
+        database_name=database_name,
+        schema_name=schema_name,
+        status=cast(types.JOB_STATUS, status),
+        compute_pool=cast(str, compute_pool),
+        target_instances=target_instances,
+    )
snowflake/ml/jobs/manager.py
CHANGED

@@ -1,6 +1,7 @@
 import json
 import logging
 import pathlib
+import sys
 import textwrap
 from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
@@ -20,6 +21,7 @@ from snowflake.ml.jobs._utils import (
     spec_utils,
     types,
 )
+from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -178,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         _ = job._service_spec
         return job
     except SnowparkSQLException as e:
-        if
-
+        if e.sql_error_code == 2003:
+            job = jb.MLJob[Any](job_id, session=session)
+            _ = job.status
+            return job
         raise
 
 
@@ -344,6 +348,9 @@ def submit_from_stage(
         query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
         spec_overrides (dict): A dictionary of overrides for the service spec.
         imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
+        runtime_environment (str): The runtime image to use. Only support image tag or full image URL,
+            e.g. "1.7.1" or "image_repo/image_name:image_tag". When it refers to a full image URL,
+            it should contain image repository, image name and image tag.
 
     Returns:
         An object representing the submitted job.
@@ -409,6 +416,7 @@ def _submit_job(
         "min_instances",
         "enable_metrics",
         "query_warehouse",
+        "runtime_environment",
     ],
 )
 def _submit_job(
@@ -441,7 +449,7 @@
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-
+        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)
 
@@ -459,6 +467,9 @@
         )
         imports = kwargs.pop("additional_payloads")
 
+    if "runtime_environment" in kwargs:
+        logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")
+
     # Use kwargs for less common optional parameters
     database = kwargs.pop("database", None)
     schema = kwargs.pop("schema", None)
@@ -470,6 +481,7 @@
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
     imports = kwargs.pop("imports", None) or imports
+    runtime_environment = kwargs.pop("runtime_environment", None)
 
     # Warn if there are unknown kwargs
     if kwargs:
@@ -503,48 +515,44 @@
         uploaded_payload = payload_utils.JobPayload(
             source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
         ).upload(session, stage_path)
-    except
+    except SnowparkSQLException as e:
         if e.sql_error_code == 90106:
             raise RuntimeError(
                 "Please specify a schema, either in the session context or as a parameter in the job submission"
             )
         raise
 
-
-    if target_instances > 1:
-        default_spec_overrides = {
-            "spec": {
-                "endpoints": [
-                    {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-                ]
-            },
-        }
-        if spec_overrides:
-            spec_overrides = spec_utils.merge_patch(
-                default_spec_overrides, spec_overrides, display_name="spec_overrides"
-            )
-        else:
-            spec_overrides = default_spec_overrides
-
-    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
+    if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
         # Add default env vars (extracted from spec_utils.generate_service_spec)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            return _do_submit_job_v2(
+                session=session,
+                payload=uploaded_payload,
+                args=args,
+                env_vars=combined_env_vars,
+                spec_overrides=spec_overrides,
+                compute_pool=compute_pool,
+                job_id=job_id,
+                external_access_integrations=external_access_integrations,
+                query_warehouse=query_warehouse,
+                target_instances=target_instances,
+                min_instances=min_instances,
+                enable_metrics=enable_metrics,
+                use_async=True,
+                runtime_environment=runtime_environment,
+            )
+        except SnowparkSQLException as e:
+            if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
+                raise
+            # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
+            # stored procedures. This will be fixed in an upcoming release.
+            logger.warning(
                "Job submission using V2 failed with error {}. Falling back to V1.".format(
                    str(e).split("\n", 1)[0],
                )
            )
 
     # Fall back to v1
     # Generate service spec
@@ -556,6 +564,7 @@
         target_instances=target_instances,
         min_instances=min_instances,
         enable_metrics=enable_metrics,
+        runtime_environment=runtime_environment,
     )
 
     # Generate spec overrides
@@ -639,6 +648,7 @@ def _do_submit_job_v2(
     min_instances: int = 1,
     enable_metrics: bool = True,
     use_async: bool = True,
+    runtime_environment: Optional[str] = None,
 ) -> jb.MLJob[Any]:
     """
     Generate the SQL query for job submission.
@@ -657,6 +667,7 @@
         min_instances: Minimum number of instances required to start the job.
         enable_metrics: Whether to enable platform metrics for the job.
         use_async: Whether to run the job asynchronously.
+        runtime_environment: image tag or full image URL to use for the job.
 
     Returns:
         The job object.
@@ -672,6 +683,13 @@
         "ENABLE_METRICS": enable_metrics,
         "SPEC_OVERRIDES": spec_overrides,
     }
+    # for the image tag or full image URL, we use that directly
+    if runtime_environment:
+        spec_options["RUNTIME"] = runtime_environment
+    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
+        # when feature flag is enabled, we get the local python version and wrap it in a dict
+        # in system function, we can know whether it is python version or image tag or full image URL through the format
+        spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
     job_options = {
         "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
         "QUERY_WAREHOUSE": query_warehouse,
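
Taken together, the manager.py changes route a new runtime_environment option from the public submission APIs through _submit_job into the RUNTIME spec option of the V2 submission path, with an automatic fallback to V1 inside owner's rights stored procedures (error 90237). A hedged sketch of a submission using it; the stage path and compute pool are placeholders, and parameter names other than runtime_environment are assumptions inferred from the docstrings above:

```python
from snowflake.ml import jobs

# Placeholder locations; the payload is assumed to already exist on the stage.
job = jobs.submit_from_stage(
    "@MY_DB.MY_SCHEMA.PAYLOAD_STAGE/my_job/",
    compute_pool="MY_COMPUTE_POOL",
    entrypoint="main.py",
    # Private preview per the warning added in _submit_job: either a bare
    # image tag such as "1.7.1" or a full "image_repo/image_name:image_tag".
    runtime_environment="1.7.1",
)
print(job.id)
```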
snowflake/ml/model/__init__.py
CHANGED

@@ -1,6 +1,20 @@
-from snowflake.ml.model._client.model.batch_inference_specs import
+from snowflake.ml.model._client.model.batch_inference_specs import (
+    JobSpec,
+    OutputSpec,
+    SaveMode,
+)
 from snowflake.ml.model._client.model.model_impl import Model
 from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
 from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
+from snowflake.ml.model.volatility import Volatility
 
-__all__ = [
+__all__ = [
+    "Model",
+    "ModelVersion",
+    "ExportMode",
+    "HuggingFacePipelineModel",
+    "JobSpec",
+    "OutputSpec",
+    "SaveMode",
+    "Volatility",
+]

snowflake/ml/model/_client/model/batch_inference_specs.py
CHANGED

@@ -1,10 +1,26 @@
-from
+from enum import Enum
+from typing import Optional
 
 from pydantic import BaseModel
 
 
+class SaveMode(str, Enum):
+    """Save mode options for batch inference output.
+
+    Determines the behavior when files already exist in the output location.
+
+    OVERWRITE: Remove existing files and write new results.
+
+    ERROR: Raise an error if files already exist in the output location.
+    """
+
+    OVERWRITE = "overwrite"
+    ERROR = "error"
+
+
 class OutputSpec(BaseModel):
     stage_location: str
+    mode: SaveMode = SaveMode.ERROR
 
 
 class JobSpec(BaseModel):
@@ -12,10 +28,10 @@ class JobSpec(BaseModel):
     job_name: Optional[str] = None
     num_workers: Optional[int] = None
     function_name: Optional[str] = None
-    gpu: Optional[Union[str, int]] = None
     force_rebuild: bool = False
     max_batch_rows: int = 1024
     warehouse: Optional[str] = None
     cpu_requests: Optional[str] = None
     memory_requests: Optional[str] = None
+    gpu_requests: Optional[str] = None
     replicas: Optional[int] = None
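
With SaveMode wired into OutputSpec, batch inference can declare up front what to do when the output stage location already holds files; _run_batch enforces the mode (via _enforce_save_mode, shown in the next diff) before copying any input data. A sketch of how the spec objects compose, using only fields visible in the hunks above; the stage path and job name are placeholders, and any required fields defined outside these hunks would still need to be supplied:

```python
from snowflake.ml.model import JobSpec, OutputSpec, SaveMode

# ERROR (the default) refuses to clobber existing files; OVERWRITE clears them first.
output_spec = OutputSpec(
    stage_location="@MY_DB.MY_SCHEMA.RESULTS_STAGE/predictions/",
    mode=SaveMode.OVERWRITE,
)

job_spec = JobSpec(
    job_name="nightly_batch_scoring",  # placeholder name
    num_workers=4,
    cpu_requests="2",
    memory_requests="8Gi",
    gpu_requests="1",  # replaces the removed gpu: Optional[Union[str, int]] field
    replicas=2,
)
```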
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED

@@ -19,7 +19,9 @@ from snowflake.ml.model._client.model import (
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
+from snowflake.ml.model._packager.model_meta import model_meta_schema
 from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
@@ -41,6 +43,7 @@ class ModelVersion(lineage_node.LineageNode):
     _model_name: sql_identifier.SqlIdentifier
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
+    _model_spec: Optional[model_meta_schema.ModelMetadataDict]
 
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -150,6 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._model_name = model_name
         self._version_name = version_name
         self._functions = self._get_functions()
+        self._model_spec = None
         super(cls, cls).__init__(
             self,
             session=model_ops._session,
@@ -437,6 +441,26 @@ class ModelVersion(lineage_node.LineageNode):
         """
         return self._functions
 
+    def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
+        """Fetch and cache the model spec for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            The model spec as a dictionary for this model version.
+        """
+        if self._model_spec is None:
+            self._model_spec = self._model_ops._fetch_model_spec(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_spec
+
     @overload
     def run(
         self,
@@ -531,6 +555,8 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
             )
         else:
+            explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 method_function_type=target_function_info["target_method_function_type"],
@@ -544,13 +570,27 @@
                 partition_column=partition_column,
                 statement_params=statement_params,
                 is_partitioned=target_function_info["is_partitioned"],
+                explain_case_sensitive=explain_case_sensitive,
             )
 
+    def _determine_explain_case_sensitivity(
+        self,
+        target_function_info: model_manifest_schema.ModelFunctionInfo,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        model_spec = self._get_model_spec(statement_params)
+        method_options = model_spec.get("method_options", {})
+        return model_method_utils.determine_explain_case_sensitive_from_method_options(
+            method_options, target_function_info["name"]
+        )
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
         func_params_to_log=[
             "compute_pool",
+            "output_spec",
+            "job_spec",
         ],
     )
     def _run_batch(
@@ -579,6 +619,8 @@ class ModelVersion(lineage_node.LineageNode):
             output_stage_location += "/"
         input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
 
+        self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
+
         try:
             input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
             # todo: be specific about the type of errors to provide better error messages.
@@ -605,6 +647,7 @@ class ModelVersion(lineage_node.LineageNode):
             warehouse=sql_identifier.SqlIdentifier(warehouse),
             cpu_requests=job_spec.cpu_requests,
             memory_requests=job_spec.memory_requests,
+            gpu_requests=job_spec.gpu_requests,
             job_name=job_name,
             replicas=job_spec.replicas,
             # input and output
@@ -798,13 +841,7 @@ class ModelVersion(lineage_node.LineageNode):
             ValueError: If the model is not a HuggingFace text-generation model.
         """
         # Fetch model spec
-        model_spec = self.
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            statement_params=statement_params,
-        )
+        model_spec = self._get_model_spec(statement_params)
 
         # Check if model_type is huggingface_pipeline
         model_type = model_spec.get("model_type")