snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +9 -14
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +61 -19
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
- snowflake/ml/jobs/_utils/spec_utils.py +44 -13
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +7 -8
- snowflake/ml/jobs/job.py +34 -18
- snowflake/ml/jobs/manager.py +107 -24
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +225 -73
- snowflake/ml/model/_client/ops/service_ops.py +128 -174
- snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import sys
|
|
3
4
|
from math import ceil
|
|
4
5
|
from pathlib import PurePath
|
|
5
|
-
from typing import Any, Optional, Union
|
|
6
|
+
from typing import Any, Literal, Optional, Union
|
|
6
7
|
|
|
7
8
|
from snowflake import snowpark
|
|
8
9
|
from snowflake.ml._internal.utils import snowflake_env
|
|
9
|
-
from snowflake.ml.jobs._utils import constants, query_helper, types
|
|
10
|
+
from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
|
|
11
|
+
from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
|
|
@@ -28,22 +30,53 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
|
|
|
28
30
|
)
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
|
|
34
|
+
rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
|
|
35
|
+
if not rows:
|
|
36
|
+
return None
|
|
37
|
+
try:
|
|
38
|
+
runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
|
|
39
|
+
spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
|
|
40
|
+
except Exception as e:
|
|
41
|
+
logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
selected_runtime = next(
|
|
45
|
+
(
|
|
46
|
+
runtime
|
|
47
|
+
for runtime in spcs_container_runtimes
|
|
48
|
+
if (
|
|
49
|
+
runtime.hardware_type.lower() == target_hardware.lower()
|
|
50
|
+
and runtime.python_version.major == sys.version_info.major
|
|
51
|
+
and runtime.python_version.minor == sys.version_info.minor
|
|
52
|
+
)
|
|
53
|
+
),
|
|
54
|
+
None,
|
|
55
|
+
)
|
|
56
|
+
return selected_runtime.runtime_container_image if selected_runtime else None
|
|
57
|
+
|
|
58
|
+
|
|
31
59
|
def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
|
|
32
60
|
# Retrieve compute pool node resources
|
|
33
61
|
resources = _get_node_resources(session, compute_pool=compute_pool)
|
|
34
62
|
|
|
35
63
|
# Use MLRuntime image
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
64
|
+
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
65
|
+
container_image = None
|
|
66
|
+
if feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
|
|
67
|
+
container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
|
|
68
|
+
|
|
69
|
+
if not container_image:
|
|
70
|
+
image_repo = constants.DEFAULT_IMAGE_REPO
|
|
71
|
+
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
|
72
|
+
image_tag = _get_runtime_image_tag()
|
|
73
|
+
container_image = f"{image_repo}/{image_name}:{image_tag}"
|
|
39
74
|
|
|
40
75
|
# TODO: Should each instance consume the entire pod?
|
|
41
76
|
return types.ImageSpec(
|
|
42
|
-
repo=image_repo,
|
|
43
|
-
image_name=image_name,
|
|
44
|
-
image_tag=image_tag,
|
|
45
77
|
resource_requests=resources,
|
|
46
78
|
resource_limits=resources,
|
|
79
|
+
container_image=container_image,
|
|
47
80
|
)
|
|
48
81
|
|
|
49
82
|
|
|
@@ -65,6 +98,7 @@ def generate_spec_overrides(
|
|
|
65
98
|
container_spec: dict[str, Any] = {
|
|
66
99
|
"name": constants.DEFAULT_CONTAINER_NAME,
|
|
67
100
|
}
|
|
101
|
+
|
|
68
102
|
if environment_vars:
|
|
69
103
|
# TODO: Validate environment variables
|
|
70
104
|
container_spec["env"] = environment_vars
|
|
@@ -180,10 +214,7 @@ def generate_service_spec(
|
|
|
180
214
|
|
|
181
215
|
# TODO: Add hooks for endpoints for integration with TensorBoard etc
|
|
182
216
|
|
|
183
|
-
env_vars =
|
|
184
|
-
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
|
|
185
|
-
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
186
|
-
}
|
|
217
|
+
env_vars = payload.env_vars
|
|
187
218
|
endpoints: list[dict[str, Any]] = []
|
|
188
219
|
|
|
189
220
|
if target_instances > 1:
|
|
@@ -220,7 +251,7 @@ def generate_service_spec(
|
|
|
220
251
|
"containers": [
|
|
221
252
|
{
|
|
222
253
|
"name": constants.DEFAULT_CONTAINER_NAME,
|
|
223
|
-
"image": image_spec.
|
|
254
|
+
"image": image_spec.container_image,
|
|
224
255
|
"command": ["/usr/local/bin/_entrypoint.sh"],
|
|
225
256
|
"args": [
|
|
226
257
|
(stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
|
|
@@ -121,15 +121,28 @@ class StagePath:
|
|
|
121
121
|
return self._compose_path(self._path)
|
|
122
122
|
|
|
123
123
|
def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath":
|
|
124
|
+
"""
|
|
125
|
+
Joins the given path arguments to the current path,
|
|
126
|
+
mimicking the behavior of pathlib.Path.joinpath.
|
|
127
|
+
If the argument is a stage path (i.e., an absolute path),
|
|
128
|
+
it overrides the current path and is returned as the final path.
|
|
129
|
+
If the argument is a normal path, it is joined with the current relative path
|
|
130
|
+
using self._path.joinpath(arg).
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
*args: Path components to join.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
A new StagePath with the joined path.
|
|
137
|
+
|
|
138
|
+
Raises:
|
|
139
|
+
NotImplementedError: the argument is a stage path.
|
|
140
|
+
"""
|
|
124
141
|
path = self
|
|
125
142
|
for arg in args:
|
|
126
|
-
|
|
143
|
+
if isinstance(arg, StagePath):
|
|
144
|
+
raise NotImplementedError
|
|
145
|
+
else:
|
|
146
|
+
# the arg might be an absolute path, so we need to remove the leading '/'
|
|
147
|
+
path = StagePath(f"{path.root}/{path._path.joinpath(arg).as_posix().lstrip('/')}")
|
|
127
148
|
return path
|
|
128
|
-
|
|
129
|
-
def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath":
|
|
130
|
-
stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
|
|
131
|
-
if self.root == stage_path.root:
|
|
132
|
-
child_path = self._path.joinpath(stage_path._path)
|
|
133
|
-
return StagePath(self._compose_path(child_path))
|
|
134
|
-
else:
|
|
135
|
-
return stage_path
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from pathlib import PurePath
|
|
4
4
|
from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
|
|
5
5
|
|
|
@@ -30,6 +30,10 @@ class PayloadPath(Protocol):
|
|
|
30
30
|
def parent(self) -> "PayloadPath":
|
|
31
31
|
...
|
|
32
32
|
|
|
33
|
+
@property
|
|
34
|
+
def root(self) -> str:
|
|
35
|
+
...
|
|
36
|
+
|
|
33
37
|
def exists(self) -> bool:
|
|
34
38
|
...
|
|
35
39
|
|
|
@@ -86,6 +90,7 @@ class UploadedPayload:
|
|
|
86
90
|
# TODO: Include manifest of payload files for validation
|
|
87
91
|
stage_path: PurePath
|
|
88
92
|
entrypoint: list[Union[str, PurePath]]
|
|
93
|
+
env_vars: dict[str, str] = field(default_factory=dict)
|
|
89
94
|
|
|
90
95
|
|
|
91
96
|
@dataclass(frozen=True)
|
|
@@ -98,12 +103,6 @@ class ComputeResources:
|
|
|
98
103
|
|
|
99
104
|
@dataclass(frozen=True)
|
|
100
105
|
class ImageSpec:
|
|
101
|
-
repo: str
|
|
102
|
-
image_name: str
|
|
103
|
-
image_tag: str
|
|
104
106
|
resource_requests: ComputeResources
|
|
105
107
|
resource_limits: ComputeResources
|
|
106
|
-
|
|
107
|
-
@property
|
|
108
|
-
def full_name(self) -> str:
|
|
109
|
-
return f"{self.repo}/{self.image_name}:{self.image_tag}"
|
|
108
|
+
container_image: str
|
snowflake/ml/jobs/job.py
CHANGED
|
@@ -99,21 +99,23 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
99
99
|
result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
|
|
100
100
|
if result_path_str is None:
|
|
101
101
|
raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
|
|
102
|
-
volume_mounts = self._container_spec["volumeMounts"]
|
|
103
|
-
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
104
102
|
|
|
103
|
+
# If result path is relative, it is relative to the stage mount path
|
|
105
104
|
result_path = Path(result_path_str)
|
|
105
|
+
if not result_path.is_absolute():
|
|
106
|
+
return f"{self._stage_path}/{result_path.as_posix()}"
|
|
107
|
+
|
|
108
|
+
# If result path is absolute, it is relative to the stage mount path
|
|
109
|
+
volume_mounts = self._container_spec["volumeMounts"]
|
|
110
|
+
stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
|
|
106
111
|
stage_mount = Path(stage_mount_str)
|
|
107
112
|
try:
|
|
108
113
|
relative_path = result_path.relative_to(stage_mount)
|
|
114
|
+
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
109
115
|
except ValueError:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
)
|
|
114
|
-
relative_path = result_path
|
|
115
|
-
|
|
116
|
-
return f"{self._stage_path}/{relative_path.as_posix()}"
|
|
116
|
+
raise ValueError(
|
|
117
|
+
f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
|
|
118
|
+
)
|
|
117
119
|
|
|
118
120
|
@overload
|
|
119
121
|
def get_logs(
|
|
@@ -199,7 +201,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
|
|
|
199
201
|
elapsed = time.monotonic() - start_time
|
|
200
202
|
if elapsed >= timeout >= 0:
|
|
201
203
|
raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds")
|
|
202
|
-
elif status == "PENDING" and not warning_shown and elapsed >=
|
|
204
|
+
elif status == "PENDING" and not warning_shown and elapsed >= 5: # Only show warning after 5s
|
|
203
205
|
pool_info = _get_compute_pool_info(self._session, self._compute_pool)
|
|
204
206
|
if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
|
|
205
207
|
logger.warning(
|
|
@@ -419,15 +421,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
|
|
|
419
421
|
if not rows:
|
|
420
422
|
return None
|
|
421
423
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
+
# we have already integrated with first_instance startup policy,
|
|
425
|
+
# the instance 0 is guaranteed to be the head instance
|
|
426
|
+
head_instance = next(
|
|
427
|
+
(
|
|
428
|
+
row
|
|
429
|
+
for row in rows
|
|
430
|
+
if "instance_id" in row and row["instance_id"] is not None and int(row["instance_id"]) == 0
|
|
431
|
+
),
|
|
432
|
+
None,
|
|
433
|
+
)
|
|
434
|
+
# fallback to find the first instance if the instance 0 is not found
|
|
435
|
+
if not head_instance:
|
|
436
|
+
if target_instances > len(rows):
|
|
437
|
+
raise RuntimeError(
|
|
438
|
+
f"Couldn’t retrieve head instance due to missing instances. {target_instances} > {len(rows)}"
|
|
439
|
+
)
|
|
440
|
+
# Sort by start_time first, then by instance_id
|
|
441
|
+
try:
|
|
442
|
+
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
443
|
+
except TypeError:
|
|
444
|
+
raise RuntimeError("Job instance information unavailable.")
|
|
445
|
+
head_instance = sorted_instances[0]
|
|
424
446
|
|
|
425
|
-
# Sort by start_time first, then by instance_id
|
|
426
|
-
try:
|
|
427
|
-
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
|
428
|
-
except TypeError:
|
|
429
|
-
raise RuntimeError("Job instance information unavailable.")
|
|
430
|
-
head_instance = sorted_instances[0]
|
|
431
447
|
if not head_instance["start_time"]:
|
|
432
448
|
# If head instance hasn't started yet, return None
|
|
433
449
|
return None
|
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import logging
|
|
2
3
|
import pathlib
|
|
3
4
|
import textwrap
|
|
5
|
+
from pathlib import PurePath
|
|
4
6
|
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
|
|
5
7
|
from uuid import uuid4
|
|
6
8
|
|
|
@@ -11,7 +13,13 @@ from snowflake import snowpark
|
|
|
11
13
|
from snowflake.ml._internal import telemetry
|
|
12
14
|
from snowflake.ml._internal.utils import identifier
|
|
13
15
|
from snowflake.ml.jobs import job as jb
|
|
14
|
-
from snowflake.ml.jobs._utils import
|
|
16
|
+
from snowflake.ml.jobs._utils import (
|
|
17
|
+
feature_flags,
|
|
18
|
+
payload_utils,
|
|
19
|
+
query_helper,
|
|
20
|
+
spec_utils,
|
|
21
|
+
types,
|
|
22
|
+
)
|
|
15
23
|
from snowflake.snowpark.context import get_active_session
|
|
16
24
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
|
17
25
|
from snowflake.snowpark.functions import coalesce, col, lit, when
|
|
@@ -426,7 +434,6 @@ def _submit_job(
|
|
|
426
434
|
|
|
427
435
|
Raises:
|
|
428
436
|
ValueError: If database or schema value(s) are invalid
|
|
429
|
-
SnowparkSQLException: If there is an error submitting the job.
|
|
430
437
|
"""
|
|
431
438
|
session = session or get_active_session()
|
|
432
439
|
|
|
@@ -446,7 +453,7 @@ def _submit_job(
|
|
|
446
453
|
env_vars = kwargs.pop("env_vars", None)
|
|
447
454
|
spec_overrides = kwargs.pop("spec_overrides", None)
|
|
448
455
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
449
|
-
query_warehouse = kwargs.pop("query_warehouse",
|
|
456
|
+
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
450
457
|
additional_payloads = kwargs.pop("additional_payloads", None)
|
|
451
458
|
|
|
452
459
|
if additional_payloads:
|
|
@@ -484,6 +491,27 @@ def _submit_job(
|
|
|
484
491
|
source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
|
|
485
492
|
).upload(session, stage_path)
|
|
486
493
|
|
|
494
|
+
if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
|
|
495
|
+
# Add default env vars (extracted from spec_utils.generate_service_spec)
|
|
496
|
+
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
497
|
+
|
|
498
|
+
return _do_submit_job_v2(
|
|
499
|
+
session=session,
|
|
500
|
+
payload=uploaded_payload,
|
|
501
|
+
args=args,
|
|
502
|
+
env_vars=combined_env_vars,
|
|
503
|
+
spec_overrides=spec_overrides,
|
|
504
|
+
compute_pool=compute_pool,
|
|
505
|
+
job_id=job_id,
|
|
506
|
+
external_access_integrations=external_access_integrations,
|
|
507
|
+
query_warehouse=query_warehouse,
|
|
508
|
+
target_instances=target_instances,
|
|
509
|
+
min_instances=min_instances,
|
|
510
|
+
enable_metrics=enable_metrics,
|
|
511
|
+
use_async=True,
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
# Fall back to v1
|
|
487
515
|
# Generate service spec
|
|
488
516
|
spec = spec_utils.generate_service_spec(
|
|
489
517
|
session,
|
|
@@ -494,6 +522,8 @@ def _submit_job(
|
|
|
494
522
|
min_instances=min_instances,
|
|
495
523
|
enable_metrics=enable_metrics,
|
|
496
524
|
)
|
|
525
|
+
|
|
526
|
+
# Generate spec overrides
|
|
497
527
|
spec_overrides = spec_utils.generate_spec_overrides(
|
|
498
528
|
environment_vars=env_vars,
|
|
499
529
|
custom_overrides=spec_overrides,
|
|
@@ -501,37 +531,25 @@ def _submit_job(
|
|
|
501
531
|
if spec_overrides:
|
|
502
532
|
spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
|
|
503
533
|
|
|
504
|
-
|
|
505
|
-
spec, external_access_integrations, query_warehouse, target_instances,
|
|
534
|
+
return _do_submit_job_v1(
|
|
535
|
+
session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
|
|
506
536
|
)
|
|
507
|
-
try:
|
|
508
|
-
_ = query_helper.run_query(session, query_text, params=params)
|
|
509
|
-
except SnowparkSQLException as e:
|
|
510
|
-
if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
|
|
511
|
-
logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
|
|
512
|
-
spec["spec"].pop("resourceManagement", None)
|
|
513
|
-
query_text, params = _generate_submission_query(
|
|
514
|
-
spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
|
|
515
|
-
)
|
|
516
|
-
_ = query_helper.run_query(session, query_text, params=params)
|
|
517
|
-
else:
|
|
518
|
-
raise
|
|
519
|
-
return get_job(job_id, session=session)
|
|
520
537
|
|
|
521
538
|
|
|
522
|
-
def
|
|
539
|
+
def _do_submit_job_v1(
|
|
540
|
+
session: snowpark.Session,
|
|
523
541
|
spec: dict[str, Any],
|
|
524
542
|
external_access_integrations: list[str],
|
|
525
543
|
query_warehouse: Optional[str],
|
|
526
544
|
target_instances: int,
|
|
527
|
-
session: snowpark.Session,
|
|
528
545
|
compute_pool: str,
|
|
529
546
|
job_id: str,
|
|
530
|
-
) ->
|
|
547
|
+
) -> jb.MLJob[Any]:
|
|
531
548
|
"""
|
|
532
549
|
Generate the SQL query for job submission.
|
|
533
550
|
|
|
534
551
|
Args:
|
|
552
|
+
session: The Snowpark session to use.
|
|
535
553
|
spec: The service spec for the job.
|
|
536
554
|
external_access_integrations: The external access integrations for the job.
|
|
537
555
|
query_warehouse: The query warehouse for the job.
|
|
@@ -541,7 +559,7 @@ def _generate_submission_query(
|
|
|
541
559
|
job_id: The ID of the job.
|
|
542
560
|
|
|
543
561
|
Returns:
|
|
544
|
-
|
|
562
|
+
The job object.
|
|
545
563
|
"""
|
|
546
564
|
query_template = textwrap.dedent(
|
|
547
565
|
"""\
|
|
@@ -559,12 +577,77 @@ def _generate_submission_query(
|
|
|
559
577
|
if external_access_integrations:
|
|
560
578
|
external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
|
|
561
579
|
query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
|
|
562
|
-
query_warehouse = query_warehouse or session.get_current_warehouse()
|
|
563
580
|
if query_warehouse:
|
|
564
581
|
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
|
565
582
|
params.append(query_warehouse)
|
|
566
583
|
if target_instances > 1:
|
|
567
584
|
query.append("REPLICAS = ?")
|
|
568
585
|
params.append(target_instances)
|
|
586
|
+
|
|
569
587
|
query_text = "\n".join(line for line in query if line)
|
|
570
|
-
|
|
588
|
+
_ = query_helper.run_query(session, query_text, params=params)
|
|
589
|
+
|
|
590
|
+
return get_job(job_id, session=session)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _do_submit_job_v2(
|
|
594
|
+
session: snowpark.Session,
|
|
595
|
+
payload: types.UploadedPayload,
|
|
596
|
+
args: Optional[list[str]],
|
|
597
|
+
env_vars: dict[str, str],
|
|
598
|
+
spec_overrides: dict[str, Any],
|
|
599
|
+
compute_pool: str,
|
|
600
|
+
job_id: Optional[str] = None,
|
|
601
|
+
external_access_integrations: Optional[list[str]] = None,
|
|
602
|
+
query_warehouse: Optional[str] = None,
|
|
603
|
+
target_instances: int = 1,
|
|
604
|
+
min_instances: int = 1,
|
|
605
|
+
enable_metrics: bool = True,
|
|
606
|
+
use_async: bool = True,
|
|
607
|
+
) -> jb.MLJob[Any]:
|
|
608
|
+
"""
|
|
609
|
+
Generate the SQL query for job submission.
|
|
610
|
+
|
|
611
|
+
Args:
|
|
612
|
+
session: The Snowpark session to use.
|
|
613
|
+
payload: The uploaded job payload.
|
|
614
|
+
args: Arguments to pass to the entrypoint script.
|
|
615
|
+
env_vars: Environment variables to set in the job container.
|
|
616
|
+
spec_overrides: Custom service specification overrides.
|
|
617
|
+
compute_pool: The compute pool to use for job execution.
|
|
618
|
+
job_id: The ID of the job.
|
|
619
|
+
external_access_integrations: Optional list of external access integrations.
|
|
620
|
+
query_warehouse: Optional query warehouse to use.
|
|
621
|
+
target_instances: Number of instances for multi-node job.
|
|
622
|
+
min_instances: Minimum number of instances required to start the job.
|
|
623
|
+
enable_metrics: Whether to enable platform metrics for the job.
|
|
624
|
+
use_async: Whether to run the job asynchronously.
|
|
625
|
+
|
|
626
|
+
Returns:
|
|
627
|
+
The job object.
|
|
628
|
+
"""
|
|
629
|
+
args = [
|
|
630
|
+
(payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
|
|
631
|
+
] + (args or [])
|
|
632
|
+
spec_options = {
|
|
633
|
+
"STAGE_PATH": payload.stage_path.as_posix(),
|
|
634
|
+
"ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
|
|
635
|
+
"ARGS": args,
|
|
636
|
+
"ENV_VARS": env_vars,
|
|
637
|
+
"ENABLE_METRICS": enable_metrics,
|
|
638
|
+
"SPEC_OVERRIDES": spec_overrides,
|
|
639
|
+
}
|
|
640
|
+
job_options = {
|
|
641
|
+
"EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
|
|
642
|
+
"QUERY_WAREHOUSE": query_warehouse,
|
|
643
|
+
"TARGET_INSTANCES": target_instances,
|
|
644
|
+
"MIN_INSTANCES": min_instances,
|
|
645
|
+
"ASYNC": use_async,
|
|
646
|
+
}
|
|
647
|
+
job_options = {k: v for k, v in job_options.items() if v is not None}
|
|
648
|
+
|
|
649
|
+
query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
|
|
650
|
+
params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
|
|
651
|
+
actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
|
|
652
|
+
|
|
653
|
+
return get_job(actual_job_id, session=session)
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
2
|
+
InputSpec,
|
|
3
|
+
JobSpec,
|
|
4
|
+
OutputSpec,
|
|
5
|
+
)
|
|
1
6
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
2
7
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
3
8
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
4
9
|
|
|
5
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
|
|
10
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InputSpec(BaseModel):
|
|
7
|
+
input_stage_location: str
|
|
8
|
+
input_file_pattern: str = "*"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OutputSpec(BaseModel):
|
|
12
|
+
output_stage_location: str
|
|
13
|
+
output_file_prefix: Optional[str] = None
|
|
14
|
+
completion_filename: str = "_SUCCESS"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class JobSpec(BaseModel):
|
|
18
|
+
image_repo: Optional[str] = None
|
|
19
|
+
job_name: Optional[str] = None
|
|
20
|
+
num_workers: Optional[int] = None
|
|
21
|
+
function_name: Optional[str] = None
|
|
22
|
+
gpu: Optional[Union[str, int]] = None
|
|
23
|
+
force_rebuild: bool = False
|
|
24
|
+
max_batch_rows: int = 1024
|
|
25
|
+
warehouse: Optional[str] = None
|
|
26
|
+
cpu_requests: Optional[str] = None
|
|
27
|
+
memory_requests: Optional[str] = None
|