snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +13 -7
- snowflake/ml/_internal/utils/connection_params.py +5 -3
- snowflake/ml/_internal/utils/jwt_generator.py +3 -2
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
- snowflake/ml/experiment/_entities/__init__.py +2 -1
- snowflake/ml/experiment/_entities/run.py +0 -15
- snowflake/ml/experiment/_entities/run_metadata.py +3 -51
- snowflake/ml/experiment/experiment_tracking.py +71 -27
- snowflake/ml/jobs/_utils/spec_utils.py +49 -11
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/__init__.py +12 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
- snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
- snowflake/ml/model/_client/model/model_version_impl.py +30 -62
- snowflake/ml/model/_client/ops/service_ops.py +68 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/service.py +29 -2
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
- snowflake/ml/model/_packager/model_env/model_env.py +26 -16
- snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
- snowflake/ml/model/_packager/model_packager.py +4 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_signatures/utils.py +0 -21
- snowflake/ml/model/models/huggingface_pipeline.py +56 -21
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +2 -1
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/utils/connection_params.py +5 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import re
|
|
3
4
|
import sys
|
|
4
5
|
from math import ceil
|
|
5
6
|
from pathlib import PurePath
|
|
@@ -10,6 +11,8 @@ from snowflake.ml._internal.utils import snowflake_env
|
|
|
10
11
|
from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
|
|
11
12
|
from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
|
|
12
13
|
|
|
14
|
+
_OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
|
|
15
18
|
"""Extract resource information for the specified compute pool"""
|
|
@@ -56,22 +59,55 @@ def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU"
|
|
|
56
59
|
return selected_runtime.runtime_container_image if selected_runtime else None
|
|
57
60
|
|
|
58
61
|
|
|
59
|
-
def
|
|
62
|
+
def _check_image_tag_valid(tag: Optional[str]) -> bool:
|
|
63
|
+
if tag is None:
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
return _OCI_TAG_REGEX.fullmatch(tag) is not None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _get_image_spec(
|
|
70
|
+
session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
|
|
71
|
+
) -> types.ImageSpec:
|
|
72
|
+
"""
|
|
73
|
+
Resolve image specification (container image and resources) for the job.
|
|
74
|
+
|
|
75
|
+
Behavior:
|
|
76
|
+
- If `runtime_environment` is empty or the feature flag is disabled, use the
|
|
77
|
+
default image tag and image name.
|
|
78
|
+
- If `runtime_environment` is a valid image tag, use that tag with the default
|
|
79
|
+
repository/name.
|
|
80
|
+
- If `runtime_environment` is a full image URL, use it directly.
|
|
81
|
+
- If the feature flag is enabled and `runtime_environment` is not provided,
|
|
82
|
+
select an ML Runtime image matching the local Python major.minor
|
|
83
|
+
- When multiple inputs are provided, `runtime_environment` takes priority.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
session: Snowflake session.
|
|
87
|
+
compute_pool: Compute pool used to infer CPU/GPU resources.
|
|
88
|
+
runtime_environment: Optional image tag or full image URL to override.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Image spec including container image and resource requests/limits.
|
|
92
|
+
"""
|
|
60
93
|
# Retrieve compute pool node resources
|
|
61
94
|
resources = _get_node_resources(session, compute_pool=compute_pool)
|
|
95
|
+
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
96
|
+
image_tag = _get_runtime_image_tag()
|
|
97
|
+
image_repo = constants.DEFAULT_IMAGE_REPO
|
|
98
|
+
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
|
62
99
|
|
|
63
100
|
# Use MLRuntime image
|
|
64
|
-
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
65
101
|
container_image = None
|
|
66
|
-
if
|
|
102
|
+
if runtime_environment:
|
|
103
|
+
if _check_image_tag_valid(runtime_environment):
|
|
104
|
+
image_tag = runtime_environment
|
|
105
|
+
else:
|
|
106
|
+
container_image = runtime_environment
|
|
107
|
+
elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
|
|
67
108
|
container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
|
|
68
109
|
|
|
69
|
-
|
|
70
|
-
image_repo = constants.DEFAULT_IMAGE_REPO
|
|
71
|
-
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
|
72
|
-
image_tag = _get_runtime_image_tag()
|
|
73
|
-
container_image = f"{image_repo}/{image_name}:{image_tag}"
|
|
74
|
-
|
|
110
|
+
container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
|
|
75
111
|
# TODO: Should each instance consume the entire pod?
|
|
76
112
|
return types.ImageSpec(
|
|
77
113
|
resource_requests=resources,
|
|
@@ -127,6 +163,7 @@ def generate_service_spec(
|
|
|
127
163
|
target_instances: int = 1,
|
|
128
164
|
min_instances: int = 1,
|
|
129
165
|
enable_metrics: bool = False,
|
|
166
|
+
runtime_environment: Optional[str] = None,
|
|
130
167
|
) -> dict[str, Any]:
|
|
131
168
|
"""
|
|
132
169
|
Generate a service specification for a job.
|
|
@@ -139,11 +176,12 @@ def generate_service_spec(
|
|
|
139
176
|
target_instances: Number of instances for multi-node job
|
|
140
177
|
enable_metrics: Enable platform metrics for the job
|
|
141
178
|
min_instances: Minimum number of instances required to start the job
|
|
179
|
+
runtime_environment: The runtime image to use. Only support image tag or full image URL.
|
|
142
180
|
|
|
143
181
|
Returns:
|
|
144
182
|
Job service specification
|
|
145
183
|
"""
|
|
146
|
-
image_spec = _get_image_spec(session, compute_pool)
|
|
184
|
+
image_spec = _get_image_spec(session, compute_pool, runtime_environment)
|
|
147
185
|
|
|
148
186
|
# Set resource requests/limits, including nvidia.com/gpu quantity if applicable
|
|
149
187
|
resource_requests: dict[str, Union[str, int]] = {
|
|
@@ -317,7 +355,7 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
|
|
|
317
355
|
Returns:
|
|
318
356
|
The patched object.
|
|
319
357
|
"""
|
|
320
|
-
if
|
|
358
|
+
if type(base) is not type(patch):
|
|
321
359
|
if base is not None:
|
|
322
360
|
logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
|
|
323
361
|
return patch
|
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import pathlib
|
|
4
|
+
import sys
|
|
4
5
|
import textwrap
|
|
5
6
|
from pathlib import PurePath
|
|
6
7
|
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
|
|
@@ -344,6 +345,9 @@ def submit_from_stage(
|
|
|
344
345
|
query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
|
|
345
346
|
spec_overrides (dict): A dictionary of overrides for the service spec.
|
|
346
347
|
imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
|
|
348
|
+
runtime_environment (str): The runtime image to use. Only support image tag or full image URL,
|
|
349
|
+
e.g. "1.7.1" or "image_repo/image_name:image_tag". When it refers to a full image URL,
|
|
350
|
+
it should contain image repository, image name and image tag.
|
|
347
351
|
|
|
348
352
|
Returns:
|
|
349
353
|
An object representing the submitted job.
|
|
@@ -409,6 +413,7 @@ def _submit_job(
|
|
|
409
413
|
"min_instances",
|
|
410
414
|
"enable_metrics",
|
|
411
415
|
"query_warehouse",
|
|
416
|
+
"runtime_environment",
|
|
412
417
|
],
|
|
413
418
|
)
|
|
414
419
|
def _submit_job(
|
|
@@ -459,6 +464,9 @@ def _submit_job(
|
|
|
459
464
|
)
|
|
460
465
|
imports = kwargs.pop("additional_payloads")
|
|
461
466
|
|
|
467
|
+
if "runtime_environment" in kwargs:
|
|
468
|
+
logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")
|
|
469
|
+
|
|
462
470
|
# Use kwargs for less common optional parameters
|
|
463
471
|
database = kwargs.pop("database", None)
|
|
464
472
|
schema = kwargs.pop("schema", None)
|
|
@@ -470,6 +478,7 @@ def _submit_job(
|
|
|
470
478
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
471
479
|
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
472
480
|
imports = kwargs.pop("imports", None) or imports
|
|
481
|
+
runtime_environment = kwargs.pop("runtime_environment", None)
|
|
473
482
|
|
|
474
483
|
# Warn if there are unknown kwargs
|
|
475
484
|
if kwargs:
|
|
@@ -544,6 +553,7 @@ def _submit_job(
|
|
|
544
553
|
min_instances=min_instances,
|
|
545
554
|
enable_metrics=enable_metrics,
|
|
546
555
|
use_async=True,
|
|
556
|
+
runtime_environment=runtime_environment,
|
|
547
557
|
)
|
|
548
558
|
|
|
549
559
|
# Fall back to v1
|
|
@@ -556,6 +566,7 @@ def _submit_job(
|
|
|
556
566
|
target_instances=target_instances,
|
|
557
567
|
min_instances=min_instances,
|
|
558
568
|
enable_metrics=enable_metrics,
|
|
569
|
+
runtime_environment=runtime_environment,
|
|
559
570
|
)
|
|
560
571
|
|
|
561
572
|
# Generate spec overrides
|
|
@@ -639,6 +650,7 @@ def _do_submit_job_v2(
|
|
|
639
650
|
min_instances: int = 1,
|
|
640
651
|
enable_metrics: bool = True,
|
|
641
652
|
use_async: bool = True,
|
|
653
|
+
runtime_environment: Optional[str] = None,
|
|
642
654
|
) -> jb.MLJob[Any]:
|
|
643
655
|
"""
|
|
644
656
|
Generate the SQL query for job submission.
|
|
@@ -657,6 +669,7 @@ def _do_submit_job_v2(
|
|
|
657
669
|
min_instances: Minimum number of instances required to start the job.
|
|
658
670
|
enable_metrics: Whether to enable platform metrics for the job.
|
|
659
671
|
use_async: Whether to run the job asynchronously.
|
|
672
|
+
runtime_environment: image tag or full image URL to use for the job.
|
|
660
673
|
|
|
661
674
|
Returns:
|
|
662
675
|
The job object.
|
|
@@ -672,6 +685,13 @@ def _do_submit_job_v2(
|
|
|
672
685
|
"ENABLE_METRICS": enable_metrics,
|
|
673
686
|
"SPEC_OVERRIDES": spec_overrides,
|
|
674
687
|
}
|
|
688
|
+
# for the image tag or full image URL, we use that directly
|
|
689
|
+
if runtime_environment:
|
|
690
|
+
spec_options["RUNTIME"] = runtime_environment
|
|
691
|
+
elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
|
|
692
|
+
# when feature flag is enabled, we get the local python version and wrap it in a dict
|
|
693
|
+
# in system function, we can know whether it is python version or image tag or full image URL through the format
|
|
694
|
+
spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
|
|
675
695
|
job_options = {
|
|
676
696
|
"EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
|
|
677
697
|
"QUERY_WAREHOUSE": query_warehouse,
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,10 +1,20 @@
|
|
|
1
1
|
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
2
|
-
InputSpec,
|
|
3
2
|
JobSpec,
|
|
4
3
|
OutputSpec,
|
|
4
|
+
SaveMode,
|
|
5
5
|
)
|
|
6
6
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
7
7
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
8
8
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
9
|
+
from snowflake.ml.model.volatility import Volatility
|
|
9
10
|
|
|
10
|
-
__all__ = [
|
|
11
|
+
__all__ = [
|
|
12
|
+
"Model",
|
|
13
|
+
"ModelVersion",
|
|
14
|
+
"ExportMode",
|
|
15
|
+
"HuggingFacePipelineModel",
|
|
16
|
+
"JobSpec",
|
|
17
|
+
"OutputSpec",
|
|
18
|
+
"SaveMode",
|
|
19
|
+
"Volatility",
|
|
20
|
+
]
|
|
@@ -1,14 +1,26 @@
|
|
|
1
|
-
from
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Optional
|
|
2
3
|
|
|
3
4
|
from pydantic import BaseModel
|
|
4
5
|
|
|
5
6
|
|
|
6
|
-
class
|
|
7
|
-
|
|
7
|
+
class SaveMode(str, Enum):
|
|
8
|
+
"""Save mode options for batch inference output.
|
|
9
|
+
|
|
10
|
+
Determines the behavior when files already exist in the output location.
|
|
11
|
+
|
|
12
|
+
OVERWRITE: Remove existing files and write new results.
|
|
13
|
+
|
|
14
|
+
ERROR: Raise an error if files already exist in the output location.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
OVERWRITE = "overwrite"
|
|
18
|
+
ERROR = "error"
|
|
8
19
|
|
|
9
20
|
|
|
10
21
|
class OutputSpec(BaseModel):
|
|
11
22
|
stage_location: str
|
|
23
|
+
mode: SaveMode = SaveMode.ERROR
|
|
12
24
|
|
|
13
25
|
|
|
14
26
|
class JobSpec(BaseModel):
|
|
@@ -16,10 +28,10 @@ class JobSpec(BaseModel):
|
|
|
16
28
|
job_name: Optional[str] = None
|
|
17
29
|
num_workers: Optional[int] = None
|
|
18
30
|
function_name: Optional[str] = None
|
|
19
|
-
gpu: Optional[Union[str, int]] = None
|
|
20
31
|
force_rebuild: bool = False
|
|
21
32
|
max_batch_rows: int = 1024
|
|
22
33
|
warehouse: Optional[str] = None
|
|
23
34
|
cpu_requests: Optional[str] = None
|
|
24
35
|
memory_requests: Optional[str] = None
|
|
36
|
+
gpu_requests: Optional[str] = None
|
|
25
37
|
replicas: Optional[int] = None
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
from snowflake.ml.model._client.ops import service_ops
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _get_inference_engine_args(
|
|
7
|
+
experimental_options: Optional[dict[str, Any]],
|
|
8
|
+
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
9
|
+
|
|
10
|
+
if not experimental_options:
|
|
11
|
+
return None
|
|
12
|
+
|
|
13
|
+
if "inference_engine" not in experimental_options:
|
|
14
|
+
raise ValueError("inference_engine is required in experimental_options")
|
|
15
|
+
|
|
16
|
+
return service_ops.InferenceEngineArgs(
|
|
17
|
+
inference_engine=experimental_options["inference_engine"],
|
|
18
|
+
inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _enrich_inference_engine_args(
|
|
23
|
+
inference_engine_args: service_ops.InferenceEngineArgs,
|
|
24
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
|
25
|
+
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
26
|
+
"""Enrich inference engine args with model path and tensor parallelism settings.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
inference_engine_args: The original inference engine args
|
|
30
|
+
gpu_requests: The number of GPUs requested
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Enriched inference engine args
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
ValueError: Invalid gpu_requests
|
|
37
|
+
"""
|
|
38
|
+
if inference_engine_args.inference_engine_args_override is None:
|
|
39
|
+
inference_engine_args.inference_engine_args_override = []
|
|
40
|
+
|
|
41
|
+
gpu_count = None
|
|
42
|
+
|
|
43
|
+
# Set tensor-parallelism if gpu_requests is specified
|
|
44
|
+
if gpu_requests is not None:
|
|
45
|
+
# assert gpu_requests is a string or an integer before casting to int
|
|
46
|
+
try:
|
|
47
|
+
gpu_count = int(gpu_requests)
|
|
48
|
+
if gpu_count > 0:
|
|
49
|
+
inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
|
|
50
|
+
else:
|
|
51
|
+
raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
|
|
52
|
+
except ValueError:
|
|
53
|
+
raise ValueError(f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}")
|
|
54
|
+
|
|
55
|
+
return inference_engine_args
|
|
@@ -12,7 +12,10 @@ from snowflake.ml._internal import telemetry
|
|
|
12
12
|
from snowflake.ml._internal.utils import sql_identifier
|
|
13
13
|
from snowflake.ml.lineage import lineage_node
|
|
14
14
|
from snowflake.ml.model import task, type_hints
|
|
15
|
-
from snowflake.ml.model._client.model import
|
|
15
|
+
from snowflake.ml.model._client.model import (
|
|
16
|
+
batch_inference_specs,
|
|
17
|
+
inference_engine_utils,
|
|
18
|
+
)
|
|
16
19
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
17
20
|
from snowflake.ml.model._model_composer import model_composer
|
|
18
21
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -22,6 +25,7 @@ from snowflake.snowpark import Session, async_job, dataframe
|
|
|
22
25
|
_TELEMETRY_PROJECT = "MLOps"
|
|
23
26
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
|
24
27
|
_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
|
|
28
|
+
_BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
class ExportMode(enum.Enum):
|
|
@@ -547,13 +551,15 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
547
551
|
subproject=_TELEMETRY_SUBPROJECT,
|
|
548
552
|
func_params_to_log=[
|
|
549
553
|
"compute_pool",
|
|
554
|
+
"output_spec",
|
|
555
|
+
"job_spec",
|
|
550
556
|
],
|
|
551
557
|
)
|
|
552
558
|
def _run_batch(
|
|
553
559
|
self,
|
|
554
560
|
*,
|
|
555
561
|
compute_pool: str,
|
|
556
|
-
input_spec:
|
|
562
|
+
input_spec: dataframe.DataFrame,
|
|
557
563
|
output_spec: batch_inference_specs.OutputSpec,
|
|
558
564
|
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
559
565
|
) -> jobs.MLJob[Any]:
|
|
@@ -569,6 +575,20 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
569
575
|
if warehouse is None:
|
|
570
576
|
raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
|
|
571
577
|
|
|
578
|
+
# use a temporary folder in the output stage to store the intermediate output from the dataframe
|
|
579
|
+
output_stage_location = output_spec.stage_location
|
|
580
|
+
if not output_stage_location.endswith("/"):
|
|
581
|
+
output_stage_location += "/"
|
|
582
|
+
input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
|
|
583
|
+
|
|
584
|
+
self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
|
|
585
|
+
|
|
586
|
+
try:
|
|
587
|
+
input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
|
|
588
|
+
# todo: be specific about the type of errors to provide better error messages.
|
|
589
|
+
except Exception as e:
|
|
590
|
+
raise RuntimeError(f"Failed to process input_spec: {e}")
|
|
591
|
+
|
|
572
592
|
if job_spec.job_name is None:
|
|
573
593
|
# Same as the MLJob ID generation logic with a different prefix
|
|
574
594
|
job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
|
|
@@ -589,12 +609,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
589
609
|
warehouse=sql_identifier.SqlIdentifier(warehouse),
|
|
590
610
|
cpu_requests=job_spec.cpu_requests,
|
|
591
611
|
memory_requests=job_spec.memory_requests,
|
|
612
|
+
gpu_requests=job_spec.gpu_requests,
|
|
592
613
|
job_name=job_name,
|
|
593
614
|
replicas=job_spec.replicas,
|
|
594
615
|
# input and output
|
|
595
|
-
input_stage_location=
|
|
616
|
+
input_stage_location=input_stage_location,
|
|
596
617
|
input_file_pattern="*",
|
|
597
|
-
output_stage_location=
|
|
618
|
+
output_stage_location=output_stage_location,
|
|
598
619
|
completion_filename="_SUCCESS",
|
|
599
620
|
# misc
|
|
600
621
|
statement_params=statement_params,
|
|
@@ -768,60 +789,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
768
789
|
version_name=sql_identifier.SqlIdentifier(version),
|
|
769
790
|
)
|
|
770
791
|
|
|
771
|
-
def _get_inference_engine_args(
|
|
772
|
-
self, experimental_options: Optional[dict[str, Any]]
|
|
773
|
-
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
774
|
-
|
|
775
|
-
if not experimental_options:
|
|
776
|
-
return None
|
|
777
|
-
|
|
778
|
-
if "inference_engine" not in experimental_options:
|
|
779
|
-
raise ValueError("inference_engine is required in experimental_options")
|
|
780
|
-
|
|
781
|
-
return service_ops.InferenceEngineArgs(
|
|
782
|
-
inference_engine=experimental_options["inference_engine"],
|
|
783
|
-
inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
|
|
784
|
-
)
|
|
785
|
-
|
|
786
|
-
def _enrich_inference_engine_args(
|
|
787
|
-
self,
|
|
788
|
-
inference_engine_args: service_ops.InferenceEngineArgs,
|
|
789
|
-
gpu_requests: Optional[Union[str, int]] = None,
|
|
790
|
-
) -> Optional[service_ops.InferenceEngineArgs]:
|
|
791
|
-
"""Enrich inference engine args with tensor parallelism settings.
|
|
792
|
-
|
|
793
|
-
Args:
|
|
794
|
-
inference_engine_args: The original inference engine args
|
|
795
|
-
gpu_requests: The number of GPUs requested
|
|
796
|
-
|
|
797
|
-
Returns:
|
|
798
|
-
Enriched inference engine args
|
|
799
|
-
|
|
800
|
-
Raises:
|
|
801
|
-
ValueError: Invalid gpu_requests
|
|
802
|
-
"""
|
|
803
|
-
if inference_engine_args.inference_engine_args_override is None:
|
|
804
|
-
inference_engine_args.inference_engine_args_override = []
|
|
805
|
-
|
|
806
|
-
gpu_count = None
|
|
807
|
-
|
|
808
|
-
# Set tensor-parallelism if gpu_requests is specified
|
|
809
|
-
if gpu_requests is not None:
|
|
810
|
-
# assert gpu_requests is a string or an integer before casting to int
|
|
811
|
-
if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
|
|
812
|
-
try:
|
|
813
|
-
gpu_count = int(gpu_requests)
|
|
814
|
-
except ValueError:
|
|
815
|
-
raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
|
|
816
|
-
|
|
817
|
-
if gpu_count is not None:
|
|
818
|
-
if gpu_count > 0:
|
|
819
|
-
inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
|
|
820
|
-
else:
|
|
821
|
-
raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
|
|
822
|
-
|
|
823
|
-
return inference_engine_args
|
|
824
|
-
|
|
825
792
|
def _check_huggingface_text_generation_model(
|
|
826
793
|
self,
|
|
827
794
|
statement_params: Optional[dict[str, Any]] = None,
|
|
@@ -1101,13 +1068,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1101
1068
|
if experimental_options:
|
|
1102
1069
|
self._check_huggingface_text_generation_model(statement_params)
|
|
1103
1070
|
|
|
1104
|
-
inference_engine_args
|
|
1105
|
-
experimental_options
|
|
1106
|
-
)
|
|
1071
|
+
inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
|
|
1107
1072
|
|
|
1108
1073
|
# Enrich inference engine args if inference engine is specified
|
|
1109
1074
|
if inference_engine_args is not None:
|
|
1110
|
-
inference_engine_args =
|
|
1075
|
+
inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
|
|
1076
|
+
inference_engine_args,
|
|
1077
|
+
gpu_requests,
|
|
1078
|
+
)
|
|
1111
1079
|
|
|
1112
1080
|
from snowflake.ml.model import event_handler
|
|
1113
1081
|
from snowflake.snowpark import exceptions
|
|
@@ -7,6 +7,7 @@ import re
|
|
|
7
7
|
import tempfile
|
|
8
8
|
import threading
|
|
9
9
|
import time
|
|
10
|
+
import warnings
|
|
10
11
|
from typing import Any, Optional, Union, cast
|
|
11
12
|
|
|
12
13
|
from snowflake import snowpark
|
|
@@ -14,6 +15,7 @@ from snowflake.ml import jobs
|
|
|
14
15
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
15
16
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
16
17
|
from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
|
|
18
|
+
from snowflake.ml.model._client.model import batch_inference_specs
|
|
17
19
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
18
20
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
19
21
|
from snowflake.snowpark import async_job, exceptions, row, session
|
|
@@ -155,16 +157,17 @@ class ServiceOperator:
|
|
|
155
157
|
database_name=database_name,
|
|
156
158
|
schema_name=schema_name,
|
|
157
159
|
)
|
|
158
|
-
|
|
160
|
+
self._stage_client = stage_sql.StageSQLClient(
|
|
161
|
+
session,
|
|
162
|
+
database_name=database_name,
|
|
163
|
+
schema_name=schema_name,
|
|
164
|
+
)
|
|
165
|
+
self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
|
|
166
|
+
if self._use_inlined_deployment_spec:
|
|
159
167
|
self._workspace = None
|
|
160
168
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
|
|
161
169
|
else:
|
|
162
170
|
self._workspace = tempfile.TemporaryDirectory()
|
|
163
|
-
self._stage_client = stage_sql.StageSQLClient(
|
|
164
|
-
session,
|
|
165
|
-
database_name=database_name,
|
|
166
|
-
schema_name=schema_name,
|
|
167
|
-
)
|
|
168
171
|
self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
|
|
169
172
|
workspace_path=pathlib.Path(self._workspace.name)
|
|
170
173
|
)
|
|
@@ -264,7 +267,14 @@ class ServiceOperator:
|
|
|
264
267
|
self._model_deployment_spec.add_hf_logger_spec(
|
|
265
268
|
hf_model_name=hf_model_args.hf_model_name,
|
|
266
269
|
hf_task=hf_model_args.hf_task,
|
|
267
|
-
hf_token=
|
|
270
|
+
hf_token=(
|
|
271
|
+
# when using inlined deployment spec, we need to use QMARK_RESERVED_TOKEN
|
|
272
|
+
# to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
|
|
273
|
+
# noop if using file-based deployment spec or token is not provided
|
|
274
|
+
service_sql.QMARK_RESERVED_TOKEN
|
|
275
|
+
if hf_model_args.hf_token and self._use_inlined_deployment_spec
|
|
276
|
+
else hf_model_args.hf_token
|
|
277
|
+
),
|
|
268
278
|
hf_tokenizer=hf_model_args.hf_tokenizer,
|
|
269
279
|
hf_revision=hf_model_args.hf_revision,
|
|
270
280
|
hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
|
|
@@ -320,6 +330,14 @@ class ServiceOperator:
|
|
|
320
330
|
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
|
321
331
|
),
|
|
322
332
|
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
|
333
|
+
query_params=(
|
|
334
|
+
# when using inlined deployment spec, we need to add the token to the query params
|
|
335
|
+
# to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
|
|
336
|
+
# noop if using file-based deployment spec or token is not provided
|
|
337
|
+
[hf_model_args.hf_token]
|
|
338
|
+
if (self._use_inlined_deployment_spec and hf_model_args and hf_model_args.hf_token)
|
|
339
|
+
else []
|
|
340
|
+
),
|
|
323
341
|
statement_params=statement_params,
|
|
324
342
|
)
|
|
325
343
|
|
|
@@ -635,6 +653,47 @@ class ServiceOperator:
|
|
|
635
653
|
else:
|
|
636
654
|
module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
|
|
637
655
|
|
|
656
|
+
def _enforce_save_mode(self, output_mode: batch_inference_specs.SaveMode, output_stage_location: str) -> None:
|
|
657
|
+
"""Enforce the save mode for the output stage location.
|
|
658
|
+
|
|
659
|
+
Args:
|
|
660
|
+
output_mode: The output mode
|
|
661
|
+
output_stage_location: The output stage location to check/clean.
|
|
662
|
+
|
|
663
|
+
Raises:
|
|
664
|
+
FileExistsError: When ERROR mode is specified and files exist in the output location.
|
|
665
|
+
RuntimeError: When operations fail (checking files or removing files).
|
|
666
|
+
ValueError: When an invalid SaveMode is specified.
|
|
667
|
+
"""
|
|
668
|
+
list_results = self._stage_client.list_stage(output_stage_location)
|
|
669
|
+
|
|
670
|
+
if output_mode == batch_inference_specs.SaveMode.ERROR:
|
|
671
|
+
if len(list_results) > 0:
|
|
672
|
+
raise FileExistsError(
|
|
673
|
+
f"Output stage location '{output_stage_location}' is not empty. "
|
|
674
|
+
f"Found {len(list_results)} existing files. When using ERROR mode, the output location "
|
|
675
|
+
f"must be empty. Please clear the existing files or use OVERWRITE mode."
|
|
676
|
+
)
|
|
677
|
+
elif output_mode == batch_inference_specs.SaveMode.OVERWRITE:
|
|
678
|
+
if len(list_results) > 0:
|
|
679
|
+
warnings.warn(
|
|
680
|
+
f"Output stage location '{output_stage_location}' is not empty. "
|
|
681
|
+
f"Found {len(list_results)} existing files. OVERWRITE mode will remove all existing files "
|
|
682
|
+
f"in the output location before running the batch inference job.",
|
|
683
|
+
stacklevel=2,
|
|
684
|
+
)
|
|
685
|
+
try:
|
|
686
|
+
self._session.sql(f"REMOVE {output_stage_location}").collect()
|
|
687
|
+
except Exception as e:
|
|
688
|
+
raise RuntimeError(
|
|
689
|
+
f"OVERWRITE was specified. However, failed to remove existing files in output stage "
|
|
690
|
+
f"{output_stage_location}: {e}. Please clear up the existing files manually and retry "
|
|
691
|
+
f"the operation."
|
|
692
|
+
)
|
|
693
|
+
else:
|
|
694
|
+
valid_modes = list(batch_inference_specs.SaveMode)
|
|
695
|
+
raise ValueError(f"Invalid SaveMode: {output_mode}. Must be one of {valid_modes}")
|
|
696
|
+
|
|
638
697
|
def _stream_service_logs(
|
|
639
698
|
self,
|
|
640
699
|
async_job: snowpark.AsyncJob,
|
|
@@ -911,6 +970,7 @@ class ServiceOperator:
|
|
|
911
970
|
max_batch_rows: Optional[int],
|
|
912
971
|
cpu_requests: Optional[str],
|
|
913
972
|
memory_requests: Optional[str],
|
|
973
|
+
gpu_requests: Optional[str],
|
|
914
974
|
replicas: Optional[int],
|
|
915
975
|
statement_params: Optional[dict[str, Any]] = None,
|
|
916
976
|
) -> jobs.MLJob[Any]:
|
|
@@ -945,6 +1005,7 @@ class ServiceOperator:
|
|
|
945
1005
|
warehouse=warehouse,
|
|
946
1006
|
cpu=cpu_requests,
|
|
947
1007
|
memory=memory_requests,
|
|
1008
|
+
gpu=gpu_requests,
|
|
948
1009
|
replicas=replicas,
|
|
949
1010
|
)
|
|
950
1011
|
|
|
@@ -204,7 +204,7 @@ class ModelDeploymentSpec:
|
|
|
204
204
|
job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
205
205
|
cpu: Optional[str] = None,
|
|
206
206
|
memory: Optional[str] = None,
|
|
207
|
-
gpu: Optional[
|
|
207
|
+
gpu: Optional[str] = None,
|
|
208
208
|
num_workers: Optional[int] = None,
|
|
209
209
|
max_batch_rows: Optional[int] = None,
|
|
210
210
|
replicas: Optional[int] = None,
|