snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +21 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +57 -0
- snowflake/ml/jobs/_utils/payload_utils.py +438 -0
- snowflake/ml/jobs/_utils/spec_utils.py +296 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +71 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_env/model_env.py +45 -28
- snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
- snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/core.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +11 -12
- snowflake/ml/model/_signatures/pandas_handler.py +11 -9
- snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +25 -4
- snowflake/ml/model/type_hints.py +15 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +28 -3
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/registry/registry.py +34 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,57 @@
|
|
1
|
+
from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
|
2
|
+
from snowflake.ml.jobs._utils.types import ComputeResources
|
3
|
+
|
4
|
+
# SPCS specification constants
|
5
|
+
DEFAULT_CONTAINER_NAME = "main"
|
6
|
+
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
7
|
+
MEMORY_VOLUME_NAME = "dshm"
|
8
|
+
STAGE_VOLUME_NAME = "stage-volume"
|
9
|
+
STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
10
|
+
|
11
|
+
# Default container image information
|
12
|
+
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|
13
|
+
DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
|
14
|
+
DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
|
15
|
+
DEFAULT_IMAGE_TAG = "0.9.2"
|
16
|
+
DEFAULT_ENTRYPOINT_PATH = "func.py"
|
17
|
+
|
18
|
+
# Fraction (0-1) of container memory to allocate for the /dev/shm volume
|
19
|
+
MEMORY_VOLUME_SIZE = 0.3
|
20
|
+
|
21
|
+
# Job status polling constants
|
22
|
+
JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
|
23
|
+
JOB_POLL_MAX_DELAY_SECONDS = 1
|
24
|
+
|
25
|
+
# Magic attributes
|
26
|
+
IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
|
27
|
+
|
28
|
+
# Compute pool resource information
|
29
|
+
# TODO: Query Snowflake for resource information instead of relying on this hardcoded
|
30
|
+
# table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
|
31
|
+
COMMON_INSTANCE_FAMILIES = {
|
32
|
+
"CPU_X64_XS": ComputeResources(cpu=1, memory=6),
|
33
|
+
"CPU_X64_S": ComputeResources(cpu=3, memory=13),
|
34
|
+
"CPU_X64_M": ComputeResources(cpu=6, memory=28),
|
35
|
+
"CPU_X64_L": ComputeResources(cpu=28, memory=116),
|
36
|
+
"HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
|
37
|
+
}
|
38
|
+
AWS_INSTANCE_FAMILIES = {
|
39
|
+
"HIGHMEM_X64_M": ComputeResources(cpu=28, memory=240),
|
40
|
+
"HIGHMEM_X64_L": ComputeResources(cpu=124, memory=984),
|
41
|
+
"GPU_NV_S": ComputeResources(cpu=6, memory=27, gpu=1, gpu_type="A10G"),
|
42
|
+
"GPU_NV_M": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="A10G"),
|
43
|
+
"GPU_NV_L": ComputeResources(cpu=92, memory=1112, gpu=8, gpu_type="A100"),
|
44
|
+
}
|
45
|
+
AZURE_INSTANCE_FAMILIES = {
|
46
|
+
"HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
|
47
|
+
"HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
|
48
|
+
"GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
|
49
|
+
"GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
|
50
|
+
"GPU_NV_2M": ComputeResources(cpu=68, memory=858, gpu=2, gpu_type="A10"),
|
51
|
+
"GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
|
52
|
+
"GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
|
53
|
+
}
|
54
|
+
CLOUD_INSTANCE_FAMILIES = {
|
55
|
+
SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
|
56
|
+
SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
|
57
|
+
}
|
@@ -0,0 +1,438 @@
|
|
1
|
+
import functools
|
2
|
+
import inspect
|
3
|
+
import io
|
4
|
+
import itertools
|
5
|
+
import pickle
|
6
|
+
import sys
|
7
|
+
import textwrap
|
8
|
+
from pathlib import Path, PurePath
|
9
|
+
from typing import (
|
10
|
+
Any,
|
11
|
+
Callable,
|
12
|
+
List,
|
13
|
+
Optional,
|
14
|
+
Type,
|
15
|
+
Union,
|
16
|
+
cast,
|
17
|
+
get_args,
|
18
|
+
get_origin,
|
19
|
+
)
|
20
|
+
|
21
|
+
import cloudpickle as cp
|
22
|
+
|
23
|
+
from snowflake import snowpark
|
24
|
+
from snowflake.ml.jobs._utils import constants, types
|
25
|
+
from snowflake.snowpark import exceptions as sp_exceptions
|
26
|
+
from snowflake.snowpark._internal import code_generation
|
27
|
+
|
28
|
+
_SUPPORTED_ARG_TYPES = {str, int, float}
|
29
|
+
_SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
|
30
|
+
_STARTUP_SCRIPT_PATH = PurePath("startup.sh")
|
31
|
+
# Bash bootstrap script that is uploaded next to every payload and invoked as
# `bash startup.sh <entrypoint> [args...]`.
#
# NOTE: the template below is a regular (non-raw) f-string, therefore:
#   * `{constants.PAYLOAD_DIR_ENV_VAR}` is substituted when this module is imported,
#   * braces that must reach bash literally are doubled (`${{VAR}}`),
#   * backslashes that must reach bash/sed literally are doubled (`\\1`, `\\.`).
#     A single-backslash `\1` would be read by Python as the octal escape for chr(1),
#     which breaks the sed back-reference and forces the 127.0.0.1 fallback.
_STARTUP_SCRIPT_CODE = textwrap.dedent(
    f"""
    #!/bin/bash

    ##### Perform common set up steps #####
    set -e # exit if a command fails

    echo "Creating log directories..."
    mkdir -p /var/log/managedservices/user/mlrs
    mkdir -p /var/log/managedservices/system/mlrs
    mkdir -p /var/log/managedservices/system/ray

    echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
    echo "" >> /etc/cron.d/ray_copy_cron
    chmod 744 /etc/cron.d/ray_copy_cron

    service cron start

    mkdir -p /tmp/prometheus-multi-dir

    # Change directory to user payload directory
    if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
        cd ${constants.PAYLOAD_DIR_ENV_VAR}
    fi

    ##### Set up Python environment #####
    export PYTHONPATH=/opt/env/site-packages/
    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
    if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
        # TODO: Prevent collisions with MLRS packages using virtualenvs
        echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
        pip install -r $MLRS_REQUIREMENTS_FILE
    fi

    MLRS_CONDA_ENV_FILE=${{MLRS_CONDA_ENV_FILE:-"environment.yml"}}
    if [ -f "${{MLRS_CONDA_ENV_FILE}}" ]; then
        # TODO: Handle conda environment
        echo "Custom conda environments not currently supported"
        exit 1
    fi
    ##### End Python environment setup #####

    ##### Ray configuration #####
    shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)

    # Configure IP address and logging directory
    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\\1/p')
    log_dir="/tmp/ray"

    # Check if eth0Ip is a valid IP address and fall back to default if necessary
    if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
        eth0Ip="127.0.0.1"
    fi

    # Common parameters for both head and worker nodes
    common_params=(
        "--node-ip-address=$eth0Ip"
        "--object-manager-port=${{RAY_OBJECT_MANAGER_PORT:-12011}}"
        "--node-manager-port=${{RAY_NODE_MANAGER_PORT:-12012}}"
        "--runtime-env-agent-port=${{RAY_RUNTIME_ENV_AGENT_PORT:-12013}}"
        "--dashboard-agent-grpc-port=${{RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}}"
        "--dashboard-agent-listen-port=${{RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}}"
        "--min-worker-port=${{RAY_MIN_WORKER_PORT:-12031}}"
        "--max-worker-port=${{RAY_MAX_WORKER_PORT:-13000}}"
        "--metrics-export-port=11502"
        "--temp-dir=$log_dir"
        "--disable-usage-stats"
    )

    # Additional head-specific parameters
    head_params=(
        "--head"
        "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
        "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Listening port for Ray Client Server
        "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
        "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc on
        "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for local debugging
        "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
    )

    # Start Ray on the head node
    ray start "${{common_params[@]}}" "${{head_params[@]}}" &
    ##### End Ray configuration #####

    # TODO: Monitor MLRS and handle process crashes
    python -m web.ml_runtime_grpc_server &

    # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started

    # Run user's Python entrypoint
    echo Running command: python "$@"
    python "$@"
    """
).strip()
|
125
|
+
|
126
|
+
|
127
|
+
def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
    """Resolve ``entrypoint`` to the path of an existing file located under ``parent``.

    Args:
        parent: The payload source; either a single file or a directory.
        entrypoint: The entrypoint file. May be ``None`` (only valid when ``parent`` is a
            file, in which case ``parent`` itself is the entrypoint), an absolute path, or
            a relative path. A relative path is interpreted relative to the current working
            directory if that yields an existing file under ``parent``; otherwise it is
            interpreted relative to ``parent``.

    Returns:
        The resolved entrypoint path.

    Raises:
        ValueError: If ``entrypoint`` is ``None`` while ``parent`` is not a file, or if an
            absolute ``entrypoint`` is not located under ``parent``.
        FileNotFoundError: If the resolved entrypoint is not an existing file.
    """
    parent = parent.absolute()
    if entrypoint is None:
        if parent.is_file():
            # Infer entrypoint from source
            entrypoint = parent
        else:
            raise ValueError("entrypoint must be provided when source is a directory")
    elif entrypoint.is_absolute():
        # Absolute path - validate it's a subpath of source dir
        if not entrypoint.is_relative_to(parent):
            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
    else:
        # Relative path
        if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
            # Relative to working dir iff path is relative to source dir and exists
            entrypoint = abs_entrypoint
        else:
            # Relative to source dir
            entrypoint = parent.joinpath(entrypoint)
    if not entrypoint.is_file():
        raise FileNotFoundError(
            "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
            f" the source directory (source={parent}, entrypoint={entrypoint})"
        )
    return entrypoint
|
153
|
+
|
154
|
+
|
155
|
+
class JobPayload:
    """A job payload (file, directory, or callable) that can be uploaded to a stage."""

    def __init__(
        self,
        source: Union[str, Path, Callable[..., Any]],
        entrypoint: Optional[Union[str, Path]] = None,
        *,
        pip_requirements: Optional[List[str]] = None,
    ) -> None:
        """Store the payload description; no validation or I/O happens here.

        Args:
            source: Path to a file or directory (``str`` values are converted to ``Path``),
                or a callable to be serialized into a generated entrypoint script.
            entrypoint: Entrypoint file; ``str`` values are converted to ``Path``.
            pip_requirements: Optional requirement specifiers, uploaded as ``requirements.txt``.
        """
        self.source = Path(source) if isinstance(source, str) else source
        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
        self.pip_requirements = pip_requirements

    def validate(self) -> None:
        """Validate the payload and normalize ``source``/``entrypoint`` for path payloads.

        Raises:
            FileNotFoundError: If a path source does not exist (or the entrypoint is missing).
            ValueError: If the entrypoint extension is unsupported, or the source is neither
                a path nor a callable.
        """
        if callable(self.source):
            # Any entrypoint value is OK for callable payloads (including None aka default)
            # since we will generate the file from the serialized callable
            pass
        elif isinstance(self.source, Path):
            # Validate source
            source = self.source
            if not source.exists():
                raise FileNotFoundError(f"{source} does not exist")
            source = source.absolute()

            # Validate entrypoint
            entrypoint = _resolve_entrypoint(source, self.entrypoint)
            if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
                raise ValueError(
                    "Unsupported entrypoint type:"
                    f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
                )

            # Update fields with normalized values
            self.source = source
            self.entrypoint = entrypoint
        else:
            raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")

    def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
        """Upload the payload, optional requirements, and the startup script to ``stage_path``.

        The stage named by the first component of ``stage_path`` is created if describing it
        fails.

        Args:
            session: Snowpark session used to run SQL and file transfers.
            stage_path: Destination stage location.

        Returns:
            The stage path together with the container command
            (``bash <startup script> <entrypoint relative to the payload root>``).
        """
        # Validate payload
        self.validate()

        # Prepare local variables
        stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
        source = self.source
        entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)

        # Create stage if necessary
        stage_name = stage_path.parts[0].lstrip("@")
        # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
        try:
            session.sql(f"describe stage {stage_name}").collect()
        except sp_exceptions.SnowparkSQLException:
            session.sql(
                f"create stage if not exists {stage_name}"
                " encryption = ( type = 'SNOWFLAKE_SSE' )"
                " comment = 'Created by snowflake.ml.jobs Python API'"
            ).collect()

        # Upload payload to stage
        if not isinstance(source, Path):
            # Callable payload: generate an entrypoint script embedding the serialized callable
            source_code = generate_python_code(source, source_code_display=True)
            _ = session.file.put_stream(
                io.BytesIO(source_code.encode()),
                stage_location=stage_path.joinpath(entrypoint).as_posix(),
                auto_compress=False,
                overwrite=True,
            )
            source = entrypoint.parent
        elif source.is_dir():
            # Manually traverse the directory and upload each file, since Snowflake PUT
            # can't handle directories. Reduce the number of PUT operations by using
            # wildcard patterns to batch upload files with the same extension.
            #
            # Resolve the root once and use it both for traversal and for computing relative
            # stage paths; mixing the resolved and unresolved forms makes relative_to() raise
            # ValueError whenever the source path contains a symlink.
            source_root = source.resolve()
            for path in {
                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p for p in source_root.rglob("*") if p.is_file()
            }:
                session.file.put(
                    str(path),
                    stage_path.joinpath(path.parent.relative_to(source_root)).as_posix(),
                    overwrite=True,
                    auto_compress=False,
                )
        else:
            # Single-file payload: the payload root becomes the file's parent directory
            session.file.put(
                str(source.resolve()),
                stage_path.as_posix(),
                overwrite=True,
                auto_compress=False,
            )
            source = source.parent

        # Upload requirements
        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
        if self.pip_requirements:
            # Upload requirements.txt to stage
            session.file.put_stream(
                io.BytesIO("\n".join(self.pip_requirements).encode()),
                stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                auto_compress=False,
                overwrite=True,
            )

        # Upload startup script
        # TODO: Make sure payload does not include file with same name
        session.file.put_stream(
            io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
            stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
            auto_compress=False,
            overwrite=False,  # FIXME
        )

        return types.UploadedPayload(
            stage_path=stage_path,
            entrypoint=[
                "bash",
                _STARTUP_SCRIPT_PATH,
                entrypoint.relative_to(source),
            ],
        )
|
274
|
+
|
275
|
+
|
276
|
+
def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
    """Return the annotated type of ``param`` with a surrounding ``Optional`` removed.

    Both ``Optional[X]`` / ``Union[X, None]`` and the PEP 604 spelling ``X | None`` are
    unwrapped to ``X``. Unions with more than one non-None member are returned unchanged.

    Args:
        param: The parameter to inspect.

    Returns:
        The (unwrapped) annotation, or ``None`` if the parameter has no annotation.
    """
    # Imported locally under an alias because the module-level name `types` refers to
    # snowflake.ml.jobs._utils.types in this file.
    import types as _stdlib_types

    param_type = param.annotation

    # Return None for empty type annotations
    if param_type is inspect.Parameter.empty:
        return None

    # Unwrap Optional type annotations. `types.UnionType` (the origin of `X | Y`) only
    # exists on Python 3.10+, hence the getattr guard.
    pep604_union = getattr(_stdlib_types, "UnionType", None)
    origin = get_origin(param_type)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)
    if is_union:
        members = get_args(param_type)
        if len(members) == 2 and type(None) in members:
            param_type = next(t for t in members if t is not type(None))

    return cast(Type[object], param_type)
|
286
|
+
|
287
|
+
|
288
|
+
def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
    """Raise ``ValueError`` unless ``param_type`` is one of ``_SUPPORTED_ARG_TYPES``.

    Args:
        param_type: The type to check.
        param_name: Parameter name, used only in the error message.
    """
    if param_type in _SUPPORTED_ARG_TYPES:
        return

    supported_names = ", ".join(t.__name__ for t in _SUPPORTED_ARG_TYPES)
    message = f"Unsupported argument type {param_type} for '{param_name}'." + f" Supported types: {supported_names}"
    raise ValueError(message)
|
295
|
+
|
296
|
+
|
297
|
+
def _generate_source_code_comment(func: Callable[..., Any]) -> str:
    """Generate a comment string containing the source code of a function for readability.

    For ``functools.partial`` objects the comment is generated from the wrapped function and
    the invocation line is rewritten to show the bound arguments, e.g.
    ``= functools.partial(fn, 1, key='v')``.

    Any failure while generating the comment is reported inside the returned comment rather
    than raised, since the comment is purely informational.
    """
    try:
        if not isinstance(func, functools.partial):
            return code_generation.generate_source_code(func)  # type: ignore[arg-type]

        # Unwrap functools.partial and generate source code comment from the original function
        wrapped = func.func
        comment = code_generation.generate_source_code(wrapped)  # type: ignore[arg-type]

        # Render the call the way it would actually be written: the wrapped function is the
        # first argument of functools.partial, followed by the bound arguments.
        partial_args = itertools.chain(
            (wrapped.__name__,),
            (repr(a) for a in func.args),
            (f"{k}={v!r}" for k, v in func.keywords.items()),
        )

        # Update invocation comment to show arguments passed via functools.partial
        return comment.replace(
            f"= {wrapped.__name__}",
            "= functools.partial({})".format(", ".join(partial_args)),
        )
    except Exception as exc:
        error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
        return code_generation.comment_source_code(error_msg)
|
319
|
+
|
320
|
+
|
321
|
+
def _serialize_callable(func: Callable[..., Any]) -> bytes:
    """Serialize ``func`` with cloudpickle.

    Raises:
        ValueError: If pickling fails. For ``functools.partial`` objects the message names
            the first component (function, positional arg, keyword arg) that fails to pickle.
    """
    try:
        serialized: bytes = cp.dumps(func)
    except pickle.PicklingError as e:
        if isinstance(func, functools.partial):
            # Try to find which part of the partial isn't serializable for better debuggability
            candidates = [("function", func.func)]
            candidates.extend((f"positional arg {idx}", arg) for idx, arg in enumerate(func.args))
            candidates.extend((f"keyword arg '{key}'", val) for key, val in func.keywords.items())
            for label, candidate in candidates:
                try:
                    cp.dumps(candidate)
                except pickle.PicklingError:
                    raise ValueError(f"Unable to serialize {label}: {candidate}") from e
        raise ValueError(f"Unable to serialize function: {func}") from e
    else:
        return serialized
|
339
|
+
|
340
|
+
|
341
|
+
def _generate_param_handler_code(signature: inspect.Signature, output_name: str = "kwargs") -> str:
    """Generate argparse source code that parses CLI arguments matching ``signature``.

    Keyword-only parameters become ``--name`` options. Every other parameter can be passed
    either positionally or as ``--name``; the two forms are placed in a mutually exclusive
    group and merged back into ``args.<name>`` after parsing.

    Args:
        signature: Signature of the function the generated script will call.
        output_name: Name of the variable in the generated code that receives the parsed
            arguments as a dict.

    Returns:
        Python source code as a string.

    Raises:
        ValueError: If a parameter is annotated with an unsupported type
            (raised by ``_validate_parameter_type``).
    """
    # Generate argparse logic for argument handling (type coercion, default values, etc)
    argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
    argparse_postproc = []
    for name, param in signature.parameters.items():
        opts = {}

        param_type = _get_parameter_type(param)
        if param_type is not None:
            _validate_parameter_type(param_type, name)
            opts["type"] = param_type.__name__

        if param.default is not inspect.Parameter.empty:
            # Option values are pasted verbatim into generated source code, so string
            # defaults must be emitted as valid literals. repr() escapes embedded quotes
            # and backslashes, which hand-written '...' quoting does not.
            opts["default"] = repr(param.default) if isinstance(param.default, str) else param.default

        all_opts = ", ".join(f"{k}={v}" for k, v in opts.items())
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            # Keyword argument
            argparse_code.append(
                f"parser.add_argument('--{name}', required={'default' not in opts}," f" {all_opts})"
            )
        else:
            # Positional argument. Use `argparse.add_mutually_exclusive_group()`
            # to allow passing positional args by name as well
            group_name = f"{name}_group"
            opts_without_default = ", ".join(f"{k}={v}" for k, v in opts.items() if k != "default")
            argparse_code.append(
                f"{group_name} = parser.add_mutually_exclusive_group(required={'default' not in opts})"
            )
            argparse_code.append(
                f"{group_name}.add_argument('pos-{name}', metavar='{name}', nargs='?'," f" {opts_without_default})"
            )
            argparse_code.append(f"{group_name}.add_argument('--{name}', {all_opts})")
            argparse_code.append("")  # Add newline for readability
            # Prefer the positional value when given; otherwise keep the --name value/default
            argparse_postproc.append(
                f"args.{name} = {name} if ({name} := args.__dict__.pop('pos-{name}')) is not None else args.{name}"
            )
    argparse_code.append("args = parser.parse_args()")
    param_code = "\n".join(argparse_code + argparse_postproc)
    param_code += f"\n{output_name} = vars(args)"

    return param_code
|
385
|
+
|
386
|
+
|
387
|
+
def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
    """Generate an entrypoint script from a Python function.

    The script contains a Python version guard, the serialized callable (hex-encoded and
    loaded with ``pickle.loads``), argument handling code, and a ``__main__`` section that
    invokes the callable.

    Args:
        func: The callable to embed. ``*args`` / ``**kwargs`` parameters are not supported.
        source_code_display: Whether to prepend the callable's source code as a comment.

    Returns:
        The generated script as a string.

    Raises:
        NotImplementedError: If ``func`` declares ``*args`` or ``**kwargs``.
    """
    sig = inspect.signature(func)
    unpacking_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    for parameter in sig.parameters.values():
        if parameter.kind in unpacking_kinds:
            raise NotImplementedError("Function must not have unpacking arguments (* or **)")

    # Mirrored from Snowpark generate_python_code() function
    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
    if source_code_display:
        source_code_comment = _generate_source_code_comment(func)
    else:
        source_code_comment = ""

    func_name = "func"
    payload_hex = _serialize_callable(func).hex()
    func_code = f"""
{source_code_comment}

import pickle
{func_name} = pickle.loads(bytes.fromhex('{payload_hex}'))
"""

    arg_dict_name = "kwargs"
    if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
        # Callables carrying the remote marker attribute are invoked without CLI arguments
        param_code = f"{arg_dict_name} = {{}}"
    else:
        param_code = _generate_param_handler_code(sig, arg_dict_name)

    # Single braces are filled in now (build-time Python version); doubled braces survive
    # into the generated script and are evaluated there (runtime Python version).
    build_major = sys.version_info.major
    build_minor = sys.version_info.minor
    indented_param_code = textwrap.indent(param_code, "    ")

    return f"""
### Version guard to check compatibility across Python versions ###
import sys
import warnings

if sys.version_info.major != {build_major} or sys.version_info.minor != {build_minor}:
    warnings.warn(
        "Python version mismatch: job was created using"
        " python{build_major}.{build_minor}"
        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
        " This will be fixed in a future release; for now, please use Python version"
        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
        RuntimeWarning,
        stacklevel=0,
    )
### End version guard ###

{func_code.strip()}

if __name__ == '__main__':
{indented_param_code}

    {func_name}(**{arg_dict_name})
"""
|