snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_handlers/_utils.py +12 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +2 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +10 -2
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +29 -14
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +187 -178
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,298 @@
|
|
1
|
+
import logging
|
2
|
+
from math import ceil
|
3
|
+
from pathlib import PurePath
|
4
|
+
from typing import Any, Dict, List, Optional, Union
|
5
|
+
|
6
|
+
from snowflake import snowpark
|
7
|
+
from snowflake.ml._internal.utils import snowflake_env
|
8
|
+
from snowflake.ml.jobs._utils import constants, types
|
9
|
+
|
10
|
+
|
11
|
+
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
    """
    Extract resource information for the specified compute pool.

    The pool's instance family is looked up first in the cloud-agnostic table
    (constants.COMMON_INSTANCE_FAMILIES), then in the table for the account's cloud
    (constants.CLOUD_INSTANCE_FAMILIES).

    Args:
        session: Snowpark session used to run the lookups.
        compute_pool: Name of the compute pool.

    Returns:
        The per-node compute resources of the pool's instance family.

    Raises:
        ValueError: If the compute pool is not found, or its instance family is not recognized.
    """
    # Escape characters that are special inside a single-quoted Snowflake string literal so a
    # user-supplied pool name cannot terminate (or inject into) the SHOW statement.
    escaped_pool = compute_pool.replace("\\", "\\\\").replace("'", "''")
    rows = session.sql(f"show compute pools like '{escaped_pool}'").collect()
    if not rows:
        raise ValueError(f"Compute pool '{compute_pool}' not found")
    instance_family: str = rows[0]["instance_family"]

    # Get the cloud we're using (AWS, Azure, etc)
    region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
    cloud = region["cloud"]

    resources = constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
    if resources is None:
        resources = constants.CLOUD_INSTANCE_FAMILIES.get(cloud, {}).get(instance_family)
    if resources is None:
        raise ValueError(
            f"Unsupported instance family '{instance_family}' for compute pool '{compute_pool}' (cloud={cloud})"
        )
    return resources
|
27
|
+
|
28
|
+
|
29
|
+
def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
    """
    Select the container image and per-node resources for a job on the given compute pool.

    Args:
        session: Snowpark session used for the lookups.
        compute_pool: Name of the compute pool the job will run on.

    Returns:
        Image spec whose resource requests and limits are both the pool's per-node resources.
    """
    # Retrieve compute pool node resources
    resources = _get_node_resources(session, compute_pool=compute_pool)

    # Use MLRuntime image; the GPU image variant is chosen when the pool's nodes have GPUs
    image_repo = constants.DEFAULT_IMAGE_REPO
    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
    image_tag = constants.DEFAULT_IMAGE_TAG

    # Try to pull latest image tag from server side if possible. The parameter name must appear
    # literally in the SQL text: a Python expression inside a plain (non f-) string is never evaluated.
    # NOTE(review): assumes the account parameter is named RUNTIME_BASE_IMAGE_TAG — confirm.
    query_result = session.sql("SHOW PARAMETERS LIKE 'RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
    if query_result and query_result[0]["value"]:
        # Only override when a non-empty value is set; an empty tag would produce "<repo>/<name>:"
        image_tag = query_result[0]["value"]

    # TODO: Should each instance consume the entire pod?
    return types.ImageSpec(
        repo=image_repo,
        image_name=image_name,
        image_tag=image_tag,
        resource_requests=resources,
        resource_limits=resources,
    )
|
51
|
+
|
52
|
+
|
53
|
+
def generate_spec_overrides(
    environment_vars: Optional[Dict[str, str]] = None,
    custom_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build a service-specification patch from environment variables and custom overrides.

    Args:
        environment_vars: Environment variables to set on the primary container
            (the container named constants.DEFAULT_CONTAINER_NAME).
        custom_overrides: Arbitrary service-specification overrides, merged on top via merge_patch.

    Returns:
        The resulting specification patch. An empty dict when neither argument supplies anything.
    """
    overrides: Dict[str, Any] = {}

    # A container-level patch is only emitted when there is something to put in it
    if environment_vars:
        # TODO: Validate environment variables
        primary_container: Dict[str, Any] = {"name": constants.DEFAULT_CONTAINER_NAME}
        primary_container["env"] = environment_vars
        overrides = {"spec": {"containers": [primary_container]}}

    # Layer user-supplied overrides on top of the generated ones
    if custom_overrides:
        overrides = merge_patch(overrides, custom_overrides, display_name="custom_overrides")

    return overrides
|
89
|
+
|
90
|
+
|
91
|
+
def generate_service_spec(
    session: snowpark.Session,
    compute_pool: str,
    payload: types.UploadedPayload,
    args: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Generate a service specification for a job.

    Args:
        session: Snowflake session
        compute_pool: Compute pool for job execution
        payload: Uploaded job payload
        args: Arguments to pass to entrypoint script

    Returns:
        Job service specification
    """
    # Set resource requests/limits, including nvidia.com/gpu quantity if applicable.
    # Requests come from image_spec.resource_requests and limits from image_spec.resource_limits.
    image_spec = _get_image_spec(session, compute_pool)
    req = image_spec.resource_requests
    lim = image_spec.resource_limits
    resource_requests: Dict[str, Union[str, int]] = {
        "cpu": f"{int(req.cpu * 1000)}m",
        "memory": f"{req.memory}Gi",
    }
    resource_limits: Dict[str, Union[str, int]] = {
        "cpu": f"{int(lim.cpu * 1000)}m",
        "memory": f"{lim.memory}Gi",
    }
    if lim.gpu > 0:
        resource_requests["nvidia.com/gpu"] = req.gpu
        resource_limits["nvidia.com/gpu"] = lim.gpu

    volumes: List[Dict[str, str]] = []
    volume_mounts: List[Dict[str, str]] = []

    # Add local volumes for ephemeral logs and artifacts
    for volume_name, mount_path in [
        ("system-logs", "/var/log/managedservices/system/mlrs"),
        ("user-logs", "/var/log/managedservices/user/mlrs"),
    ]:
        volume_mounts.append({"name": volume_name, "mountPath": mount_path})
        volumes.append({"name": volume_name, "source": "local"})

    # Mount a memory-backed volume at /dev/shm. Its size is the memory limit scaled by
    # constants.MEMORY_VOLUME_SIZE (rounded up to whole GiB), capped at the memory request.
    memory_volume_name = "dshm"
    memory_volume_size = min(ceil(lim.memory * constants.MEMORY_VOLUME_SIZE), req.memory)
    volume_mounts.append({"name": memory_volume_name, "mountPath": "/dev/shm"})
    volumes.append({"name": memory_volume_name, "source": "memory", "size": f"{memory_volume_size}Gi"})

    # Mount payload as volume
    stage_mount = PurePath("/opt/app")
    stage_volume_name = "stage-volume"
    volume_mounts.append({"name": stage_volume_name, "mountPath": stage_mount.as_posix()})
    volumes.append({"name": stage_volume_name, "source": payload.stage_path.as_posix()})

    # Entrypoint tokens that are paths are relative to the payload; resolve them under the mount point
    entrypoint_args = [stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v for v in payload.entrypoint]

    # TODO: Add hooks for endpoints for integration with TensorBoard etc

    # Assemble into service specification dict
    spec = {
        "spec": {
            "containers": [
                {
                    "name": constants.DEFAULT_CONTAINER_NAME,
                    "image": image_spec.full_name,
                    "command": ["/usr/local/bin/_entrypoint.sh"],
                    "args": entrypoint_args + (args or []),
                    "env": {
                        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
                    },
                    "volumeMounts": volume_mounts,
                    "resources": {
                        "requests": resource_requests,
                        "limits": resource_limits,
                    },
                },
            ],
            "volumes": volumes,
        }
    }

    return spec
|
208
|
+
|
209
|
+
|
210
|
+
def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
    """
    Implements a modified RFC7386 JSON Merge Patch
    https://datatracker.ietf.org/doc/html/rfc7386

    Behavior differs from the RFC in the following ways:
    1. Empty nested dictionaries resulting from the patch are treated as None and are pruned.
       As a consequence, a patch value of ``{}`` removes the corresponding key (in the RFC it is a no-op).
    2. Attempts to merge lists of dicts using a merge key (default "name").
       See _merge_lists_of_dicts for details on list merge behavior.
    3. If ``base`` and ``patch`` have different types (including ``base`` being None / missing),
       ``patch`` is returned verbatim, so any None "removal" values inside it are kept as-is.

    Neither ``base`` nor ``patch`` is mutated; dicts along the patched path are shallow-copied.

    Args:
        base: The base object to patch.
        patch: The patch object.
        display_name: The name of the patch object for logging purposes.

    Returns:
        The patched object.
    """
    if type(base) is not type(patch):
        if base is not None:
            logging.warning(
                "Type mismatch while merging %s (base=%s, patch=%s)", display_name, type(base), type(patch)
            )
        return patch
    if isinstance(patch, list) and all(isinstance(v, dict) for v in base + patch):
        # TODO: Should we prune empty lists?
        return _merge_lists_of_dicts(base, patch, display_name=display_name)
    if not isinstance(patch, dict) or not patch:
        # Scalars, non-dict lists, and empty-dict patches replace the base value
        return patch

    result = dict(base)  # Shallow copy
    for key, value in patch.items():
        if value is None:
            result.pop(key, None)
            continue
        merged = merge_patch(result.get(key), value, display_name=f"{display_name}.{key}")
        if isinstance(merged, dict) and not merged:
            # Prune keys whose merged value is an empty dict
            result.pop(key, None)
        else:
            result[key] = merged

    return result
|
250
|
+
|
251
|
+
|
252
|
+
def _merge_lists_of_dicts(
|
253
|
+
base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
|
254
|
+
) -> List[Dict[str, Any]]:
|
255
|
+
"""
|
256
|
+
Attempts to merge lists of dicts by matching on a merge key (default "name").
|
257
|
+
- If the merge key is missing, the behavior falls back to overwriting the list.
|
258
|
+
- If the merge key is present, the behavior is to match the list elements based on the
|
259
|
+
merge key and preserving any unmatched elements from the base list.
|
260
|
+
- Matched entries may be dropped in the following way(s):
|
261
|
+
1. The matching patch entry has a None key entry, e.g. { "name": "foo", None: None }.
|
262
|
+
|
263
|
+
Args:
|
264
|
+
base: The base list of dicts.
|
265
|
+
patch: The patch list of dicts.
|
266
|
+
merge_key: The key to use for merging.
|
267
|
+
display_name: The name of the patch object for logging purposes.
|
268
|
+
|
269
|
+
Returns:
|
270
|
+
The merged list of dicts if merging successful, else returns the patch list.
|
271
|
+
"""
|
272
|
+
if any(merge_key not in d for d in base + patch):
|
273
|
+
logging.warning(f"Missing merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
|
274
|
+
return patch
|
275
|
+
|
276
|
+
# Build mapping of merge key values to list elements for the base list
|
277
|
+
result = {d[merge_key]: d for d in base}
|
278
|
+
if len(result) != len(base):
|
279
|
+
logging.warning(f"Duplicate merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
|
280
|
+
return patch
|
281
|
+
|
282
|
+
# Apply patches
|
283
|
+
for d in patch:
|
284
|
+
key = d[merge_key]
|
285
|
+
|
286
|
+
# Removal case 1: `None` key in patch entry
|
287
|
+
if None in d:
|
288
|
+
result.pop(key, None)
|
289
|
+
continue
|
290
|
+
|
291
|
+
# Apply patch
|
292
|
+
if key in result:
|
293
|
+
d = merge_patch(result[key], d, display_name=f"{display_name}[{merge_key}={d[merge_key]}]")
|
294
|
+
# TODO: Should we drop the item if the patch result is empty save for the merge key?
|
295
|
+
# Can check `d.keys() <= {merge_key}`
|
296
|
+
result[key] = d
|
297
|
+
|
298
|
+
return list(result.values())
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from pathlib import PurePath
|
3
|
+
from typing import List, Literal, Optional, Union
|
4
|
+
|
5
|
+
# Every execution status value a job can report
JOB_STATUS = Literal["PENDING", "RUNNING", "FAILED", "DONE", "INTERNAL_ERROR"]
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass(frozen=True)
class UploadedPayload:
    """Immutable record of a job payload that has been uploaded to a stage."""

    # TODO: Include manifest of payload files for validation
    stage_path: PurePath  # Stage location holding the uploaded payload files
    entrypoint: List[Union[str, PurePath]]  # Entrypoint tokens; PurePath entries presumably name payload files
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass(frozen=True)
class ComputeResources:
    """Immutable description of the compute capacity of a single node."""

    cpu: float  # vCPU cores
    memory: float  # Memory, in GiB
    gpu: int = 0  # GPU count; 0 means no GPUs
    gpu_type: Optional[str] = None  # GPU model, when known
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass(frozen=True)
class ImageSpec:
    """Container image reference together with the resources to request and cap for it."""

    repo: str  # Image repository path
    image_name: str  # Image name within the repository
    image_tag: str  # Image tag
    resource_requests: ComputeResources  # Resources requested for the container
    resource_limits: ComputeResources  # Resource caps for the container

    @property
    def full_name(self) -> str:
        """The complete image reference, formatted as ``<repo>/<image_name>:<image_tag>``."""
        return "{}/{}:{}".format(self.repo, self.image_name, self.image_tag)
|
@@ -0,0 +1,91 @@
|
|
1
|
+
import copy
|
2
|
+
import functools
|
3
|
+
import inspect
|
4
|
+
from typing import Callable, Dict, List, Optional, TypeVar
|
5
|
+
|
6
|
+
from typing_extensions import ParamSpec
|
7
|
+
|
8
|
+
from snowflake import snowpark
|
9
|
+
from snowflake.ml._internal import telemetry
|
10
|
+
from snowflake.ml.jobs import job as jb, manager as jm
|
11
|
+
from snowflake.ml.jobs._utils import payload_utils
|
12
|
+
|
13
|
+
_PROJECT = "MLJob"
|
14
|
+
|
15
|
+
_Args = ParamSpec("_Args")
|
16
|
+
_ReturnValue = TypeVar("_ReturnValue")
|
17
|
+
|
18
|
+
|
19
|
+
def _copy_function(
    func: Callable[_Args, _ReturnValue], first_line_offset: int = 0
) -> Callable[_Args, _ReturnValue]:
    """
    Create an independent copy of a plain Python function, optionally shifting its first line number.

    ``copy.copy`` cannot be used for this because it returns the very same function object. Assigning
    ``__code__`` on that "copy" would mutate the caller's function, and the shift would compound each
    time the same function is decorated again.

    Args:
        func: The function to copy. It must be a plain Python function with ``__code__`` and ``__globals__``.
        first_line_offset: Amount added to the code object's ``co_firstlineno``.

    Returns:
        A new function object that shares globals, defaults and closure with ``func``.
    """
    import types as pytypes  # stdlib `types`; imported locally because this module does not import it

    code = func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + first_line_offset)
    copied = pytypes.FunctionType(code, func.__globals__, func.__name__, func.__defaults__, func.__closure__)
    copied.__kwdefaults__ = func.__kwdefaults__
    copied.__qualname__ = func.__qualname__
    copied.__module__ = func.__module__
    copied.__doc__ = func.__doc__
    copied.__annotations__ = dict(func.__annotations__)
    # Copy function attributes, but do not add __wrapped__ (as functools.update_wrapper would).
    # inspect.getsource() unwraps __wrapped__ and would then ignore the adjusted line number.
    copied.__dict__.update(func.__dict__)
    return copied


@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def remote(
    compute_pool: str,
    stage_name: str,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    env_vars: Optional[Dict[str, str]] = None,
    session: Optional[snowpark.Session] = None,
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
    """
    Submit a job to the compute pool.

    Each call of the decorated function validates the argument types and converts the arguments to
    strings: positional args as-is, keyword args as ``--<name> <value>`` pairs. It then submits the
    function as a job and returns the resulting MLJob instead of running the function locally.

    Args:
        compute_pool: The compute pool to use for the job.
        stage_name: The name of the stage where the job payload will be uploaded.
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        env_vars: Environment variables to set in container
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        Decorator that dispatches invocations of the decorated function as remote jobs.
    """

    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
        # Work on a real copy so the user's function is left untouched. Its first line number is
        # advanced by one to exclude the (single-line) decorator from the copied source code.
        wrapped_func = _copy_function(func, first_line_offset=1)

        # Validate function arguments based on signature
        signature = inspect.signature(func)
        pos_arg_names = []
        for name, param in signature.parameters.items():
            param_type = payload_utils.get_parameter_type(param)
            if param_type is not None:
                payload_utils.validate_parameter_type(param_type, name)
            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                pos_arg_names.append(name)

        @functools.wraps(func)
        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
            # Validate positional args
            for i, arg in enumerate(args):
                arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
                payload_utils.validate_parameter_type(type(arg), arg_name)

            # Validate keyword args
            for k, v in kwargs.items():
                payload_utils.validate_parameter_type(type(v), k)

            arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
            job = jm._submit_job(
                source=wrapped_func,
                args=arg_list,
                stage_name=stage_name,
                compute_pool=compute_pool,
                pip_requirements=pip_requirements,
                external_access_integrations=external_access_integrations,
                query_warehouse=query_warehouse,
                env_vars=env_vars,
                session=session,
            )
            assert isinstance(job, jb.MLJob)
            return job

        return wrapper

    return decorator
|
snowflake/ml/jobs/job.py
ADDED
@@ -0,0 +1,113 @@
|
|
1
|
+
import time
|
2
|
+
from typing import Any, List, Optional, cast
|
3
|
+
|
4
|
+
from snowflake import snowpark
|
5
|
+
from snowflake.ml._internal import telemetry
|
6
|
+
from snowflake.ml.jobs._utils import constants, types
|
7
|
+
from snowflake.snowpark.context import get_active_session
|
8
|
+
|
9
|
+
_PROJECT = "MLJob"  # Project name reported by API-usage telemetry in this module

# Once a job reports one of these statuses, MLJob stops re-querying the backend (see MLJob.status / MLJob.wait)
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
11
|
+
|
12
|
+
|
13
|
+
class MLJob:
    """Client-side handle to a submitted job, identified by its job (service) ID."""

    def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
        """
        Args:
            id: The job ID.
            session: Snowpark session used for status and log queries. Defaults to the active session.
        """
        self._id = id
        self._session = session or get_active_session()
        # Cached status; refreshed from the backend until a terminal status is observed
        self._status: types.JOB_STATUS = "PENDING"

    @property
    def id(self) -> str:
        """Get the unique job ID"""
        return self._id

    @property
    def status(self) -> types.JOB_STATUS:
        """
        Get the job's execution status.

        Each access queries the backend until a status in TERMINAL_JOB_STATUSES has been seen.
        After that, the cached terminal status is returned without further queries.
        """
        if self._status not in TERMINAL_JOB_STATUSES:
            # Query backend for job status if not in terminal state
            self._status = _get_status(self._session, self.id)
        return self._status

    @snowpark._internal.utils.private_preview(version="1.7.4")
    def get_logs(self, limit: int = -1) -> str:
        """
        Return the job's execution logs.

        Args:
            limit: The maximum number of lines to return. Values <= 0 are treated as no limit.

        Returns:
            The job's execution logs.
        """
        logs = _get_logs(self._session, self.id, limit)
        assert isinstance(logs, str)  # mypy
        return logs

    @snowpark._internal.utils.private_preview(version="1.7.4")
    def show_logs(self, limit: int = -1) -> None:
        """
        Display the job's execution logs.

        Args:
            limit: The maximum number of lines to display. Values <= 0 are treated as no limit.
        """
        print(self.get_logs(limit))  # noqa: T201: we need to print here.

    @snowpark._internal.utils.private_preview(version="1.7.4")
    @telemetry.send_api_usage_telemetry(project=_PROJECT)
    def wait(self, timeout: float = -1) -> types.JOB_STATUS:
        """
        Block until completion. Returns completion status.

        Polling uses exponential backoff, starting at constants.JOB_POLL_INITIAL_DELAY_SECONDS and capped
        at constants.JOB_POLL_MAX_DELAY_SECONDS. It never sleeps past the timeout deadline.

        Args:
            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.

        Returns:
            The job's completion status.

        Raises:
            TimeoutError: If the job does not complete within the specified timeout.
        """
        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS
        start_time = time.monotonic()
        while self.status not in TERMINAL_JOB_STATUSES:
            elapsed = time.monotonic() - start_time
            if timeout >= 0 and elapsed >= timeout:
                raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
            # Cap the sleep to the remaining time so a grown backoff delay cannot overshoot the deadline
            time.sleep(delay if timeout < 0 else min(delay, timeout - elapsed))
            delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
        return self.status
|
80
|
+
|
81
|
+
|
82
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
    """
    Retrieve job execution status.

    Runs ``DESCRIBE SERVICE`` for the job and reads the ``status`` column of its single result row.
    A ValueError is raised by the unpacking if the query does not return exactly one row.
    """
    # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
    #       `DESCRIBE` queries with bind variables
    #       Switch to use bind variables instead of client side formatting after
    #       updating to snowflake-snowpark-python>=1.24.0
    describe_rows = session.sql(f"DESCRIBE SERVICE {job_id}").collect()
    (service_row,) = describe_rows
    raw_status = service_row["status"]
    return cast(types.JOB_STATUS, raw_status)
|
91
|
+
|
92
|
+
|
93
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
    """
    Retrieve the job's execution logs.

    Args:
        session: The Snowpark session used to run the query.
        job_id: The job ID.
        limit: The maximum number of lines to return. Values <= 0 (including the default -1)
            are treated as no limit.

    Returns:
        The job's execution logs.
    """
    bind_params: List[Any] = [job_id]
    limit_placeholder = ""
    if limit > 0:
        # Only pass the line-count argument when an actual limit was requested
        bind_params.append(limit)
        limit_placeholder = ", ?"
    query = f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{limit_placeholder})"
    (log_row,) = session.sql(query, params=bind_params).collect()
    return str(log_row[0])
|