snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,352 @@
|
|
1
|
+
import inspect
|
2
|
+
import io
|
3
|
+
import sys
|
4
|
+
import textwrap
|
5
|
+
from pathlib import Path, PurePath
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Callable,
|
9
|
+
List,
|
10
|
+
Optional,
|
11
|
+
Type,
|
12
|
+
Union,
|
13
|
+
cast,
|
14
|
+
get_args,
|
15
|
+
get_origin,
|
16
|
+
)
|
17
|
+
|
18
|
+
import cloudpickle as cp
|
19
|
+
|
20
|
+
from snowflake import snowpark
|
21
|
+
from snowflake.ml.jobs._utils import constants, types
|
22
|
+
from snowflake.snowpark._internal import code_generation
|
23
|
+
|
24
|
+
# Python types that callable-payload parameters may be annotated with. They are passed to argparse
# as ``type=`` converters in the generated entrypoint script (see ``generate_python_code``).
_SUPPORTED_ARG_TYPES = {str, int, float}

# Name of the startup script uploaded next to the payload, relative to the stage path.
_STARTUP_SCRIPT_PATH = PurePath("startup.sh")

# Bash startup script used as the job's entrypoint. In order, it:
#   1. creates log directories, registers a cron entry and starts cron;
#   2. cds into the directory named by the env var ``constants.PAYLOAD_DIR_ENV_VAR`` when it is non-empty;
#   3. pip-installs ``$MLRS_REQUIREMENTS_FILE`` (default ``requirements.txt``) if that file exists, and
#      exits with status 1 if ``$MLRS_CONDA_ENV_FILE`` (default ``environment.yml``) exists;
#   4. starts a Ray head node and the ML runtime gRPC server in the background
#      (``shm_size`` is computed but not used below);
#   5. runs ``python "$@"`` with the arguments passed to the script.
# NOTE: this is a (non-raw) f-string: ``{...}`` interpolates Python values, literal bash braces are
# written ``{{...}}``, and backslashes must be doubled (``\\1``) so Python does not treat them as escapes.
_STARTUP_SCRIPT_CODE = textwrap.dedent(
    f"""
    #!/bin/bash

    ##### Perform common set up steps #####
    set -e # exit if a command fails

    echo "Creating log directories..."
    mkdir -p /var/log/managedservices/user/mlrs
    mkdir -p /var/log/managedservices/system/mlrs
    mkdir -p /var/log/managedservices/system/ray

    echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
    echo "" >> /etc/cron.d/ray_copy_cron
    chmod 744 /etc/cron.d/ray_copy_cron

    service cron start

    mkdir -p /tmp/prometheus-multi-dir

    # Change directory to user payload directory
    if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
        cd ${constants.PAYLOAD_DIR_ENV_VAR}
    fi

    ##### Set up Python environment #####
    export PYTHONPATH=/opt/env/site-packages/
    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
    if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
        # TODO: Prevent collisions with MLRS packages using virtualenvs
        echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
        pip install -r $MLRS_REQUIREMENTS_FILE
    fi

    MLRS_CONDA_ENV_FILE=${{MLRS_CONDA_ENV_FILE:-"environment.yml"}}
    if [ -f "${{MLRS_CONDA_ENV_FILE}}" ]; then
        # TODO: Handle conda environment
        echo "Custom conda environments not currently supported"
        exit 1
    fi
    ##### End Python environment setup #####

    ##### Ray configuration #####
    shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)

    # Configure IP address and logging directory
    eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\\1/p')
    log_dir="/tmp/ray"

    # Check if eth0Ip is empty and set default if necessary
    if [ -z "$eth0Ip" ]; then
        # This should never happen, but just in case ethOIp is not set, we should default to localhost
        eth0Ip="127.0.0.1"
    fi

    # Common parameters for both head and worker nodes
    common_params=(
        "--node-ip-address=$eth0Ip"
        "--object-manager-port=${{RAY_OBJECT_MANAGER_PORT:-12011}}"
        "--node-manager-port=${{RAY_NODE_MANAGER_PORT:-12012}}"
        "--runtime-env-agent-port=${{RAY_RUNTIME_ENV_AGENT_PORT:-12013}}"
        "--dashboard-agent-grpc-port=${{RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}}"
        "--dashboard-agent-listen-port=${{RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}}"
        "--min-worker-port=${{RAY_MIN_WORKER_PORT:-12031}}"
        "--max-worker-port=${{RAY_MAX_WORKER_PORT:-13000}}"
        "--metrics-export-port=11502"
        "--temp-dir=$log_dir"
        "--disable-usage-stats"
    )

    # Additional head-specific parameters
    head_params=(
        "--head"
        "--port=${{RAY_HEAD_GCS_PORT:-12001}}" # Port of Ray (GCS server)
        "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}" # Listening port for Ray Client Server
        "--dashboard-host=${{NODE_IP_ADDRESS}}" # Host to bind the dashboard server
        "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}" # Dashboard head to listen for grpc on
        "--dashboard-port=${{DASHBOARD_PORT}}" # Port to bind the dashboard server for local debugging
        "--resources={{\\"node_tag:head\\":1}}" # Resource tag for selecting head as coordinator
    )

    # Start Ray on the head node
    ray start "${{common_params[@]}}" "${{head_params[@]}}" &
    ##### End Ray configuration #####

    # TODO: Monitor MLRS and handle process crashes
    python -m web.ml_runtime_grpc_server &

    # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started

    # Run user's Python entrypoint
    echo Running command: python "$@"
    python "$@"
    """
).strip()
|
121
|
+
|
122
|
+
|
123
|
+
class JobPayload:
    """A job payload (a local file, a local directory, or a Python callable) and how to launch it.

    Args:
        source: Path to a local file or directory (``str`` values are converted to ``Path``), or a
            callable that is serialized into a generated entrypoint script.
        entrypoint: Python file to execute. For a directory source it may be given relative to
            ``source``. If omitted, it defaults to ``source`` for a file source and to
            ``constants.DEFAULT_ENTRYPOINT_PATH`` for a callable.
        pip_requirements: Optional pip requirement specifiers, uploaded as ``requirements.txt``.
    """

    def __init__(
        self,
        source: Union[str, Path, Callable[..., Any]],
        entrypoint: Optional[Union[str, Path]] = None,
        *,
        pip_requirements: Optional[List[str]] = None,
    ) -> None:
        self.source = Path(source) if isinstance(source, str) else source
        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
        self.pip_requirements = pip_requirements

    def validate(self) -> None:
        """Check that the payload can be uploaded, filling in ``self.entrypoint`` for path sources.

        If the entrypoint is not an existing file as given, it is re-interpreted relative to ``source``.

        Raises:
            FileNotFoundError: If ``source`` or the entrypoint file does not exist.
            ValueError: If ``source`` is a directory and no entrypoint was given, if the entrypoint is not
                ``source`` itself or located under it, or if ``source`` is neither a path nor a callable.
            NotImplementedError: If the entrypoint is not a ``.py`` file.
        """
        if callable(self.source):
            # Any entrypoint value is OK for callable payloads (including None aka default)
            # since we will generate the file from the serialized callable
            pass
        elif isinstance(self.source, Path):
            # Validate self.source and self.entrypoint for files
            if not self.source.exists():
                raise FileNotFoundError(f"{self.source} does not exist")
            if self.entrypoint is None:
                if self.source.is_file():
                    self.entrypoint = self.source
                else:
                    raise ValueError("entrypoint must be provided when source is a directory")
            if not self.entrypoint.is_file():
                # Check if self.entrypoint is a valid relative path
                self.entrypoint = self.source.joinpath(self.entrypoint)
                if not self.entrypoint.is_file():
                    raise FileNotFoundError(f"File {self.entrypoint} does not exist")
            # Lexical check; upload() relies on it when computing the entrypoint relative to source.
            if not self.entrypoint.is_relative_to(self.source):
                raise ValueError(f"{self.entrypoint} must be a subpath of {self.source}")
            if self.entrypoint.suffix != ".py":
                raise NotImplementedError("Only Python entrypoints are supported currently")
        else:
            raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")

    def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
        """Validate the payload and upload it, together with the startup script, to ``stage_path``.

        The first component of ``stage_path`` is created as a stage if it does not already exist.
        Directory sources keep their relative layout on the stage. If ``pip_requirements`` is set, a
        ``requirements.txt`` is written to ``stage_path``.

        Args:
            session: Snowpark session used to create the stage and upload files.
            stage_path: Stage location to upload to; its first path component is the stage name.

        Returns:
            The stage path and the command (``bash startup.sh <entrypoint>``) that launches the payload.
        """
        # Validate payload
        self.validate()

        # Prepare local variables
        stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
        source = self.source
        entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)

        # Create stage if necessary
        stage_name = stage_path.parts[0]
        session.sql(
            f"create stage if not exists {stage_name.lstrip('@')}"
            " encryption = ( type = 'SNOWFLAKE_SSE' )"
            " comment = 'Created by snowflake.ml.jobs Python API'"
        ).collect()

        # Upload payload to stage
        if not isinstance(source, Path):
            source_code = generate_python_code(source, source_code_display=True)
            _ = session.file.put_stream(
                io.BytesIO(source_code.encode()),
                stage_location=stage_path.joinpath(entrypoint).as_posix(),
                auto_compress=False,
                overwrite=True,
            )
            source = entrypoint.parent
        elif source.is_dir():
            # Manually traverse the directory and upload each file, since Snowflake PUT
            # can't handle directories. Reduce the number of PUT operations by using
            # wildcard patterns to batch upload files with the same extension.
            # rglob() over the resolved directory yields absolute paths, so stage subdirectories must be
            # computed relative to the resolved root rather than the (possibly relative) original source.
            source_root = source.resolve()
            upload_patterns = {
                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
                for p in source_root.rglob("*")
                if p.is_file()
            }
            for path in upload_patterns:
                session.file.put(
                    str(path),
                    stage_path.joinpath(path.parent.relative_to(source_root)).as_posix(),
                    overwrite=True,
                    auto_compress=False,
                )
        else:
            session.file.put(
                str(source.resolve()),
                stage_path.as_posix(),
                overwrite=True,
                auto_compress=False,
            )
            source = source.parent

        # Upload requirements
        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
        if self.pip_requirements:
            # Upload requirements.txt to stage
            session.file.put_stream(
                io.BytesIO("\n".join(self.pip_requirements).encode()),
                stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                auto_compress=False,
                overwrite=True,
            )

        # Upload startup script
        # TODO: Make sure payload does not include file with same name
        session.file.put_stream(
            io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
            stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
            auto_compress=False,
            overwrite=False,  # FIXME
        )

        return types.UploadedPayload(
            stage_path=stage_path,
            entrypoint=[
                "bash",
                _STARTUP_SCRIPT_PATH,
                entrypoint.relative_to(source),
            ],
        )
|
238
|
+
|
239
|
+
|
240
|
+
def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
    """Return the effective type annotation of a function parameter.

    ``Optional[X]`` (``Union[X, None]``) and, on Python 3.10+, ``X | None`` are unwrapped to ``X``.
    Any other annotation is returned unchanged.

    Args:
        param: The parameter to inspect.

    Returns:
        The (unwrapped) annotation, or ``None`` if the parameter has no annotation.
    """
    # Aliased because the module-level name ``types`` refers to snowflake.ml.jobs._utils.types.
    import types as _std_types

    param_type = param.annotation

    # Return None for empty type annotations
    if param_type is inspect.Parameter.empty:
        return None

    # PEP 604 unions (``X | None``) report ``types.UnionType`` as their origin instead of ``typing.Union``.
    union_type = getattr(_std_types, "UnionType", None)
    union_origins = (Union,) if union_type is None else (Union, union_type)

    # Unwrap Optional type annotations
    type_args = get_args(param_type)
    if get_origin(param_type) in union_origins and len(type_args) == 2 and type(None) in type_args:
        param_type = next(t for t in type_args if t is not type(None))

    return cast(Type[object], param_type)
|
250
|
+
|
251
|
+
|
252
|
+
def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
    """Ensure a callable-payload parameter type can be parsed by the generated entrypoint.

    Args:
        param_type: The parameter's annotated type.
        param_name: Name of the parameter, used in the error message.

    Raises:
        ValueError: If ``param_type`` is not one of ``_SUPPORTED_ARG_TYPES``.
    """
    if param_type in _SUPPORTED_ARG_TYPES:
        return

    supported_names = ", ".join(t.__name__ for t in _SUPPORTED_ARG_TYPES)
    raise ValueError(
        f"Unsupported argument type {param_type} for '{param_name}'."
        f" Supported types: {supported_names}"
    )
|
259
|
+
|
260
|
+
|
261
|
+
def _generate_argparse_code(signature: inspect.Signature, func_name: str) -> str:
    """Generate script statements that parse ``sys.argv`` per ``signature`` and call ``func_name``."""
    argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
    argparse_postproc = []
    positional_only = []
    for name, param in signature.parameters.items():
        opts = {}

        param_type = get_parameter_type(param)
        if param_type is not None:
            validate_parameter_type(param_type, name)
            opts["type"] = param_type.__name__

        if param.default is not inspect.Parameter.empty:
            # repr() yields a correctly quoted/escaped literal, e.g. for strings containing quotes.
            opts["default"] = repr(param.default)

        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            # Keyword argument
            argparse_code.append(
                f"parser.add_argument('--{name}', required={'default' not in opts},"
                f" {', '.join(f'{k}={v}' for k, v in opts.items())})"
            )
        else:
            if param.kind == inspect.Parameter.POSITIONAL_ONLY:
                positional_only.append(name)
            # Positional argument. Use `argparse.add_mutually_exclusive_group()`
            # to allow passing positional args by name as well
            group_name = f"{name}_group"
            argparse_code.append(
                f"{group_name} = parser.add_mutually_exclusive_group(required={'default' not in opts})"
            )
            argparse_code.append(
                f"{group_name}.add_argument('pos-{name}', metavar='{name}', nargs='?',"
                f" {', '.join(f'{k}={v}' for k, v in opts.items() if k != 'default')})"
            )
            argparse_code.append(
                f"{group_name}.add_argument('--{name}', {', '.join(f'{k}={v}' for k, v in opts.items())})"
            )
            argparse_code.append("")  # Add newline for readability
            # Fold the positional form into args.<name>. The walrus target is a fixed temporary so that
            # parameters named e.g. `func`, `args` or `parser` cannot rebind the script's own names.
            argparse_postproc.append(
                f"args.{name} = _pos_value if (_pos_value := args.__dict__.pop('pos-{name}')) is not None"
                f" else args.{name}"
            )
    argparse_code.append("args = parser.parse_args()")

    if positional_only:
        # Positional-only parameters cannot be passed by keyword: pop them and pass them positionally.
        call_code = [
            f"_positional_args = [args.__dict__.pop(_name) for _name in {tuple(positional_only)!r}]",
            f"{func_name}(*_positional_args, **vars(args))",
        ]
    else:
        call_code = [f"{func_name}(**vars(args))"]

    return "\n".join(argparse_code + argparse_postproc + [""] + call_code)


def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
    """Generate a standalone Python script that runs ``func`` with command-line arguments.

    The script embeds ``func`` as hex-encoded cloudpickle bytes. At runtime it warns if the Python
    major/minor version differs from the one used to generate it, parses ``sys.argv`` with argparse
    according to ``func``'s signature, and calls ``func`` with the parsed values. Positional-or-keyword
    parameters may be given positionally or as ``--name``; keyword-only parameters only as ``--name``.

    Args:
        func: The callable to serialize. Its annotated parameters must use a supported argument type
            (optionally wrapped in ``Optional``).
        source_code_display: If True, include ``func``'s source as a comment in the generated script.

    Returns:
        The generated script source.

    Raises:
        NotImplementedError: If ``func`` has ``*args`` or ``**kwargs`` parameters.
        ValueError: If a parameter is annotated with an unsupported type.
    """
    signature = inspect.signature(func)
    if any(
        p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
        for p in signature.parameters.values()
    ):
        raise NotImplementedError("Function must not have unpacking arguments (* or **)")

    # Mirrored from Snowpark generate_python_code() function
    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
    try:
        source_code_comment = (
            code_generation.generate_source_code(func) if source_code_display else ""  # type: ignore[arg-type]
        )
    except Exception as exc:
        # Best effort: the source comment is informational only, so embed the error instead of failing.
        error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
        source_code_comment = code_generation.comment_source_code(error_msg)

    func_name = "func"
    # The pickled bytes are produced right here from `func`, not read from external input.
    func_code = f"""
{source_code_comment}

import pickle
{func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
"""

    # Generate argparse logic for argument handling (type coercion, default values, etc)
    param_code = _generate_argparse_code(signature, func_name)

    return f"""
### Version guard to check compatibility across Python versions ###
import sys
import warnings

if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
    warnings.warn(
        "Python version mismatch: job was created using"
        " python{sys.version_info.major}.{sys.version_info.minor}"
        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
        " This will be fixed in a future release; for now, please use Python version"
        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
        RuntimeWarning,
        stacklevel=0,
    )
### End version guard ###

{func_code.strip()}

if __name__ == '__main__':
{textwrap.indent(param_code, '    ')}
"""
|
@@ -0,0 +1,298 @@
|
|
1
|
+
import logging
|
2
|
+
from math import ceil
|
3
|
+
from pathlib import PurePath
|
4
|
+
from typing import Any, Dict, List, Optional, Union
|
5
|
+
|
6
|
+
from snowflake import snowpark
|
7
|
+
from snowflake.ml._internal.utils import snowflake_env
|
8
|
+
from snowflake.ml.jobs._utils import constants, types
|
9
|
+
|
10
|
+
|
11
|
+
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
    """Look up the per-node compute resources of ``compute_pool``.

    The pool's instance family is read via ``SHOW COMPUTE POOLS`` and mapped to resources using
    ``constants.COMMON_INSTANCE_FAMILIES`` first, then the table for the current region's cloud.

    Raises:
        ValueError: If no compute pool matches ``compute_pool``.
        KeyError: If the cloud or instance family has no entry in the resource tables.
    """
    pool_rows = session.sql(f"show compute pools like '{compute_pool}'").collect()
    if not pool_rows:
        raise ValueError(f"Compute pool '{compute_pool}' not found")
    instance_family: str = pool_rows[0]["instance_family"]

    # Determine which cloud provider (AWS, Azure, etc) hosts the current region
    all_regions = snowflake_env.get_regions(session)
    current_region = all_regions[snowflake_env.get_current_region_id(session)]
    cloud = current_region["cloud"]

    shared_resources = constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
    if shared_resources:
        return shared_resources
    return constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
|
27
|
+
|
28
|
+
|
29
|
+
def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
    """Choose the container image and resources for jobs running on ``compute_pool``.

    The GPU runtime image is used when the pool's nodes report GPUs, otherwise the CPU image. The image
    tag is taken from the account-level runtime image tag parameter when it has a non-empty value, and
    falls back to ``constants.DEFAULT_IMAGE_TAG`` otherwise.

    Args:
        session: Snowpark session used to query the compute pool and account parameters.
        compute_pool: Name of the compute pool the job will run on.

    Returns:
        An image spec whose resource requests and limits both equal the pool's per-node resources.
    """
    # Retrieve compute pool node resources
    resources = _get_node_resources(session, compute_pool=compute_pool)

    # Use MLRuntime image
    image_repo = constants.DEFAULT_IMAGE_REPO
    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
    image_tag = constants.DEFAULT_IMAGE_TAG

    # Try to pull latest image tag from server side if possible. The parameter name must be
    # interpolated into the query text, not written as the literal expression 'constants.…'.
    # NOTE(review): falls back to the bare name if constants does not define RUNTIME_BASE_IMAGE_TAG — confirm.
    tag_parameter = getattr(constants, "RUNTIME_BASE_IMAGE_TAG", "RUNTIME_BASE_IMAGE_TAG")
    query_result = session.sql(f"SHOW PARAMETERS LIKE '{tag_parameter}' IN ACCOUNT").collect()
    if query_result and query_result[0]["value"]:
        # Skip an empty parameter value so the image reference never ends up with an empty tag
        image_tag = query_result[0]["value"]

    # TODO: Should each instance consume the entire pod?
    return types.ImageSpec(
        repo=image_repo,
        image_name=image_name,
        image_tag=image_tag,
        resource_requests=resources,
        resource_limits=resources,
    )
|
51
|
+
|
52
|
+
|
53
|
+
def generate_spec_overrides(
    environment_vars: Optional[Dict[str, str]] = None,
    custom_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build a patch dict of service specification overrides.

    Args:
        environment_vars: Environment variables to set in the primary container
            (the container named constants.DEFAULT_CONTAINER_NAME).
        custom_overrides: Arbitrary service specification overrides, applied last
            on top of the generated container overrides via merge_patch.

    Returns:
        The resulting service specification patch dict; empty when no overrides
        were supplied.
    """
    overrides: Dict[str, Any] = {}

    # Collect container-level settings; a container entry is only emitted when
    # at least one of them is present.
    container_settings: Dict[str, Any] = {}
    if environment_vars:
        # TODO: Validate environment variables
        container_settings["env"] = environment_vars

    if container_settings:
        # The entry carries the container name so list merging in merge_patch
        # (which matches list-of-dict entries on "name") can target it.
        primary_container = {"name": constants.DEFAULT_CONTAINER_NAME, **container_settings}
        overrides = {"spec": {"containers": [primary_container]}}

    if not custom_overrides:
        return overrides
    return merge_patch(overrides, custom_overrides, display_name="custom_overrides")
|
89
|
+
|
90
|
+
|
91
|
+
def generate_service_spec(
    session: snowpark.Session,
    compute_pool: str,
    payload: types.UploadedPayload,
    args: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Generate a service specification for a job.

    Args:
        session: Snowflake session
        compute_pool: Compute pool for job execution
        payload: Uploaded job payload
        args: Arguments to pass to entrypoint script

    Returns:
        Job service specification dict with a single container
        (constants.DEFAULT_CONTAINER_NAME) and its volumes.
    """
    image_spec = _get_image_spec(session, compute_pool)
    node_requests = image_spec.resource_requests
    node_limits = image_spec.resource_limits

    # Set resource requests/limits, including nvidia.com/gpu quantity if applicable.
    # Requests come from resource_requests and limits from resource_limits so the two
    # stay correct even if _get_image_spec ever returns different values for them.
    resource_requests: Dict[str, Union[str, int]] = {
        "cpu": f"{int(node_requests.cpu * 1000)}m",  # vCPU cores -> millicores
        "memory": f"{node_requests.memory}Gi",
    }
    resource_limits: Dict[str, Union[str, int]] = {
        "cpu": f"{int(node_limits.cpu * 1000)}m",
        "memory": f"{node_limits.memory}Gi",
    }
    if node_limits.gpu > 0:
        resource_requests["nvidia.com/gpu"] = node_requests.gpu
        resource_limits["nvidia.com/gpu"] = node_limits.gpu

    # Add local volumes for ephemeral logs and artifacts
    volumes: List[Dict[str, str]] = []
    volume_mounts: List[Dict[str, str]] = []
    for volume_name, mount_path in [
        ("system-logs", "/var/log/managedservices/system/mlrs"),
        ("user-logs", "/var/log/managedservices/user/mlrs"),
    ]:
        volume_mounts.append({"name": volume_name, "mountPath": mount_path})
        volumes.append({"name": volume_name, "source": "local"})

    # Mount a fraction (constants.MEMORY_VOLUME_SIZE) of the memory limit as a
    # memory-backed /dev/shm volume, capped at the memory request.
    memory_volume_name = "dshm"
    memory_volume_size = min(
        ceil(node_limits.memory * constants.MEMORY_VOLUME_SIZE),
        node_requests.memory,
    )
    volume_mounts.append({"name": memory_volume_name, "mountPath": "/dev/shm"})
    volumes.append(
        {
            "name": memory_volume_name,
            "source": "memory",
            "size": f"{memory_volume_size}Gi",
        }
    )

    # Mount payload as volume
    stage_mount = PurePath("/opt/app")
    stage_volume_name = "stage-volume"
    volume_mounts.append({"name": stage_volume_name, "mountPath": stage_mount.as_posix()})
    volumes.append({"name": stage_volume_name, "source": payload.stage_path.as_posix()})

    # TODO: Add hooks for endpoints for integration with TensorBoard etc

    # PurePath entrypoint parts are resolved under the payload mount; plain strings
    # are passed through unchanged. User args are appended after the entrypoint.
    container_args = [
        stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v for v in payload.entrypoint
    ] + (args or [])

    # Assemble into service specification dict
    spec = {
        "spec": {
            "containers": [
                {
                    "name": constants.DEFAULT_CONTAINER_NAME,
                    "image": image_spec.full_name,
                    "command": ["/usr/local/bin/_entrypoint.sh"],
                    "args": container_args,
                    "env": {
                        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
                    },
                    "volumeMounts": volume_mounts,
                    "resources": {
                        "requests": resource_requests,
                        "limits": resource_limits,
                    },
                },
            ],
            "volumes": volumes,
        }
    }

    return spec
|
208
|
+
|
209
|
+
|
210
|
+
def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
    """
    Implements a modified RFC7386 JSON Merge Patch
    https://datatracker.ietf.org/doc/html/rfc7386

    Behavior differs from the RFC in the following ways:
      1. Empty nested dictionaries resulting from the patch are treated as None and are pruned
      2. Attempts to merge lists of dicts using a merge key (default "name").
         See _merge_lists_of_dicts for details on list merge behavior.

    As in the RFC, a dict patch applied to a missing or non-dict base is merged into
    an empty dict, so None values (and empty dicts) inside newly added subtrees are
    dropped rather than copied into the result.

    Args:
        base: The base object to patch.
        patch: The patch object.
        display_name: The name of the patch object for logging purposes.

    Returns:
        The patched object. Neither ``base`` nor ``patch`` is modified in place;
        the result may share unmodified sub-objects with them.
    """
    if type(base) is not type(patch):
        if base is not None:
            logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
        if not isinstance(patch, dict):
            return patch
        # RFC 7386: a non-object target is replaced by {} before applying an object
        # patch, which strips None values from the newly introduced subtree.
        base = {}
    elif isinstance(patch, list) and all(isinstance(v, dict) for v in base + patch):
        # TODO: Should we prune empty lists?
        return _merge_lists_of_dicts(base, patch, display_name=display_name)

    # Scalars and non-dict lists replace the base outright. An empty dict patch is
    # returned as-is; the caller then prunes it (modification 1 above).
    if not isinstance(patch, dict) or len(patch) == 0:
        return patch

    result = dict(base)  # Shallow copy
    for key, value in patch.items():
        if value is None:
            result.pop(key, None)
            continue
        merge_result = merge_patch(result.get(key, None), value, display_name=f"{display_name}.{key}")
        if isinstance(merge_result, dict) and len(merge_result) == 0:
            result.pop(key, None)
        else:
            result[key] = merge_result

    return result
|
250
|
+
|
251
|
+
|
252
|
+
def _merge_lists_of_dicts(
|
253
|
+
base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
|
254
|
+
) -> List[Dict[str, Any]]:
|
255
|
+
"""
|
256
|
+
Attempts to merge lists of dicts by matching on a merge key (default "name").
|
257
|
+
- If the merge key is missing, the behavior falls back to overwriting the list.
|
258
|
+
- If the merge key is present, the behavior is to match the list elements based on the
|
259
|
+
merge key and preserving any unmatched elements from the base list.
|
260
|
+
- Matched entries may be dropped in the following way(s):
|
261
|
+
1. The matching patch entry has a None key entry, e.g. { "name": "foo", None: None }.
|
262
|
+
|
263
|
+
Args:
|
264
|
+
base: The base list of dicts.
|
265
|
+
patch: The patch list of dicts.
|
266
|
+
merge_key: The key to use for merging.
|
267
|
+
display_name: The name of the patch object for logging purposes.
|
268
|
+
|
269
|
+
Returns:
|
270
|
+
The merged list of dicts if merging successful, else returns the patch list.
|
271
|
+
"""
|
272
|
+
if any(merge_key not in d for d in base + patch):
|
273
|
+
logging.warning(f"Missing merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
|
274
|
+
return patch
|
275
|
+
|
276
|
+
# Build mapping of merge key values to list elements for the base list
|
277
|
+
result = {d[merge_key]: d for d in base}
|
278
|
+
if len(result) != len(base):
|
279
|
+
logging.warning(f"Duplicate merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
|
280
|
+
return patch
|
281
|
+
|
282
|
+
# Apply patches
|
283
|
+
for d in patch:
|
284
|
+
key = d[merge_key]
|
285
|
+
|
286
|
+
# Removal case 1: `None` key in patch entry
|
287
|
+
if None in d:
|
288
|
+
result.pop(key, None)
|
289
|
+
continue
|
290
|
+
|
291
|
+
# Apply patch
|
292
|
+
if key in result:
|
293
|
+
d = merge_patch(result[key], d, display_name=f"{display_name}[{merge_key}={d[merge_key]}]")
|
294
|
+
# TODO: Should we drop the item if the patch result is empty save for the merge key?
|
295
|
+
# Can check `d.keys() <= {merge_key}`
|
296
|
+
result[key] = d
|
297
|
+
|
298
|
+
return list(result.values())
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from pathlib import PurePath
|
3
|
+
from typing import List, Literal, Optional, Union
|
4
|
+
|
5
|
+
# Closed set of status strings a job may report.
JOB_STATUS = Literal["PENDING", "RUNNING", "FAILED", "DONE", "INTERNAL_ERROR"]
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass(frozen=True)
class UploadedPayload:
    """Immutable record of an uploaded job payload.

    Attributes:
        stage_path: Stage location the payload files were uploaded to.
        entrypoint: Command parts used to launch the job. Plain strings are used
            verbatim; PurePath parts are presumably relative to the payload root
            (the service spec generator joins them onto the payload mount point).
    """

    # TODO: Include manifest of payload files for validation
    stage_path: PurePath
    entrypoint: List[Union[str, PurePath]]
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass(frozen=True)
class ComputeResources:
    """Immutable description of the compute capacity of a node or container.

    Attributes:
        cpu: Number of vCPU cores (may be fractional).
        memory: Memory in GiB.
        gpu: Number of GPUs; 0 when the node has none.
        gpu_type: Optional GPU model label; None when not specified.
    """

    cpu: float
    memory: float
    gpu: int = 0
    gpu_type: Optional[str] = None
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass(frozen=True)
class ImageSpec:
    """Immutable container image reference plus the resources it should run with.

    Attributes:
        repo: Image repository prefix.
        image_name: Image name within the repository.
        image_tag: Image tag.
        resource_requests: Resources requested for the container.
        resource_limits: Upper resource limits for the container.
    """

    repo: str
    image_name: str
    image_tag: str
    resource_requests: ComputeResources
    resource_limits: ComputeResources

    @property
    def full_name(self) -> str:
        """Fully qualified image reference in the form ``repo/image_name:image_tag``."""
        return "{}/{}:{}".format(self.repo, self.image_name, self.image_tag)
|