snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +19 -0
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/fileset/fileset.py +6 -0
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/ops/model_ops.py +11 -2
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_packager/model_handlers/_utils.py +12 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +2 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +10 -2
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +29 -14
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +187 -178
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py
CHANGED
@@ -49,6 +49,10 @@ class CompleteOptions(TypedDict):
     generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
     that the model outputs, while temperature influences which tokens are chosen at each step. """
 
+    guardrails: NotRequired[bool]
+    """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
+    from the language model. """
+
 
 class ResponseParseException(Exception):
     """This exception is raised when the server response cannot be parsed."""
@@ -56,6 +60,15 @@ class ResponseParseException(Exception):
     pass
 
 
+class GuardrailsOptions(TypedDict):
+    enabled: bool
+    """A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
+    from the language model."""
+
+    response_when_unsafe: str
+    """The response to return when the language model generates unsafe or harmful content."""
+
+
 _MAX_RETRY_SECONDS = 30
 
 
@@ -117,6 +130,12 @@ def _make_request_body(
             data["temperature"] = options["temperature"]
         if "top_p" in options:
             data["top_p"] = options["top_p"]
+        if "guardrails" in options and options["guardrails"]:
+            guardrails_options: GuardrailsOptions = {
+                "enabled": True,
+                "response_when_unsafe": "Response filtered by Cortex Guard",
+            }
+            data["guardrails"] = guardrails_options
     return data
 
 
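The new guardrails flag rides along with the existing options dict. A minimal usage sketch, assuming an active Snowpark session and that the public snowflake.cortex.Complete entry point accepts the options keyword as in prior releases (model name and prompt are placeholders):

    from snowflake.cortex import Complete, CompleteOptions

    # When guardrails=True, _make_request_body sends
    # {"enabled": True, "response_when_unsafe": "Response filtered by Cortex Guard"}.
    options = CompleteOptions(guardrails=True)
    response = Complete("mistral-large", "Summarize our support policy.", options=options)
    print(response)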
snowflake/ml/_internal/platform_capabilities.py
ADDED
@@ -0,0 +1,87 @@
+import json
+from typing import Any, Dict, Optional
+
+from absl import logging
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import query_result_checker
+from snowflake.snowpark import (
+    exceptions as snowpark_exceptions,
+    session as snowpark_session,
+)
+
+
+class PlatformCapabilities:
+    """Class that retrieves platform feature values for the currently running server.
+
+    Example usage:
+    ```
+    pc = PlatformCapabilities.get_instance(session)
+    if pc.is_nested_function_enabled():
+        # Nested functions are enabled.
+        print("Nested functions are enabled.")
+    else:
+        # Nested functions are disabled.
+        print("Nested functions are disabled or not supported.")
+    ```
+    """
+
+    _instance: Optional["PlatformCapabilities"] = None
+
+    @classmethod
+    def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
+        if not cls._instance:
+            cls._instance = cls(session)
+        return cls._instance
+
+    def is_nested_function_enabled(self) -> bool:
+        return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
+
+    @staticmethod
+    def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
+        try:
+            result = (
+                query_result_checker.SqlResultValidator(
+                    session=session,
+                    query="SELECT SYSTEM$ML_PLATFORM_CAPABILITIES() AS FEATURES;",
+                )
+                .has_dimensions(expected_rows=1, expected_cols=1)
+                .has_column("FEATURES")
+                .validate()[0]
+            )
+            if "FEATURES" in result:
+                capabilities_json: str = result["FEATURES"]
+                try:
+                    parsed_json = json.loads(capabilities_json)
+                    assert isinstance(parsed_json, dict), f"Expected JSON object, got {type(parsed_json)}"
+                    return parsed_json
+                except json.JSONDecodeError as e:
+                    message = f"""Unable to parse JSON from: "{capabilities_json}"; Error="{e}"."""
+                    raise exceptions.SnowflakeMLException(
+                        error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
+                    )
+        except snowpark_exceptions.SnowparkSQLException as e:
+            logging.debug(f"Failed to retrieve platform capabilities: {e}")
+            # This can happen is server side is older than 9.2. That is fine.
+            return {}
+
+    def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
+        if not session:
+            session = next(iter(snowpark_session._get_active_sessions()))
+        assert session, "Missing active session object"
+        self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
+
+    def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
+        value = self.features.get(feature_name, default_value)
+        if isinstance(value, bool):
+            return value
+        if isinstance(value, int) and value in [0, 1]:
+            return value == 1
+        if isinstance(value, str):
+            if value.lower() in ["true", "1"]:
+                return True
+            elif value.lower() in ["false", "0"]:
+                return False
+            else:
+                raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
+        raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
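The _get_bool_feature helper tolerates several encodings of a boolean flag. A minimal sketch of the coercion rules it implements, constructing the object without a session (bypassing __init__ purely for illustration; the feature names are invented):

    from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

    pc = PlatformCapabilities.__new__(PlatformCapabilities)  # skip the session lookup
    pc.features = {"FLAG_A": True, "FLAG_B": 1, "FLAG_C": "false"}

    assert pc._get_bool_feature("FLAG_A", False) is True   # bool passes through
    assert pc._get_bool_feature("FLAG_B", False) is True   # ints 0/1 are accepted
    assert pc._get_bool_feature("FLAG_C", True) is False   # "true"/"false"/"1"/"0" strings
    assert pc._get_bool_feature("MISSING", True) is True   # absent -> default value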
snowflake/ml/dataset/dataset.py
CHANGED
@@ -419,7 +419,6 @@ class Dataset(lineage_node.LineageNode):
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
     def delete(self) -> None:
         """Delete Dataset and all contained versions"""
-        # TODO: Check and warn if any versions exist
         self._session.sql(f"DROP DATASET {self.fully_qualified_name}").collect(
             statement_params=_TELEMETRY_STATEMENT_PARAMS
         )
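The only change here is the removed TODO; delete() still drops the dataset and every contained version in a single DROP DATASET statement. A sketch, assuming load_dataset as the accessor (the dataset name and version are placeholders, and the accessor choice is an assumption not shown in this diff):

    from snowflake.ml.dataset import load_dataset

    ds = load_dataset(session, "MY_DB.MY_SCHEMA.MY_DATASET", "v1")
    ds.delete()  # executes: DROP DATASET MY_DB.MY_SCHEMA.MY_DATASET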
snowflake/ml/fileset/fileset.py
CHANGED
@@ -2,6 +2,8 @@ import functools
 import inspect
 from typing import Any, Callable, List, Optional
 
+from typing_extensions import deprecated
+
 from snowflake import snowpark
 from snowflake.connector import connection
 from snowflake.ml._internal import telemetry
@@ -42,6 +44,10 @@ def _raise_if_deleted(func: Callable[..., Any]) -> Callable[..., Any]:
     return raise_if_deleted_helper
 
 
+@deprecated(
+    "FileSet is deprecated and will be removed in a future release."
+    " Use snowflake.ml.dataset.Dataset and snowflake.ml.data.DataConnector instead"
+)
 class FileSet:
     """A FileSet represents an immutable snapshot of the result of a query in the form of files."""
 
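For context, typing_extensions.deprecated both marks the class for static type checkers and emits a DeprecationWarning when the decorated class is instantiated at runtime. A self-contained sketch of the mechanism on a toy class (not the actual FileSet constructor):

    import warnings

    from typing_extensions import deprecated

    @deprecated("OldThing is deprecated; use NewThing instead")
    class OldThing:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        OldThing()  # instantiating triggers the runtime warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)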
snowflake/ml/jobs/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from snowflake.ml.jobs._utils.types import JOB_STATUS
+from snowflake.ml.jobs.decorators import remote
+from snowflake.ml.jobs.job import MLJob
+from snowflake.ml.jobs.manager import (
+    delete_job,
+    get_job,
+    list_jobs,
+    submit_directory,
+    submit_file,
+)
+
+__all__ = [
+    "remote",
+    "submit_file",
+    "submit_directory",
+    "list_jobs",
+    "get_job",
+    "delete_job",
+    "MLJob",
+    "JOB_STATUS",
+]
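Together these exports form the new headless-job surface: decorate a function with remote or submit a file/directory, then track the returned MLJob. A hedged sketch; the remote and submit_file signatures live in decorators.py and manager.py, which this diff page does not expand, so the compute_pool and stage_name arguments and the status attribute below are assumptions based on the exported names:

    from snowflake.ml import jobs

    @jobs.remote("MY_COMPUTE_POOL", stage_name="payload_stage")  # assumed parameters
    def train(data_path: str, epochs: int = 10) -> None:
        print(f"Training on {data_path} for {epochs} epochs")

    job = train("@my_stage/data")  # assumed to submit to SPCS and return an MLJob handle
    print(job.status)              # assumed accessor for a JOB_STATUS value
    print(jobs.list_jobs())        # enumerate submitted jobs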
snowflake/ml/jobs/_utils/constants.py
ADDED
@@ -0,0 +1,51 @@
+from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
+from snowflake.ml.jobs._utils.types import ComputeResources
+
+# SPCS specification constants
+DEFAULT_CONTAINER_NAME = "main"
+PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
+
+# Default container image information
+DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
+DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
+DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
+DEFAULT_IMAGE_TAG = "0.8.0"
+DEFAULT_ENTRYPOINT_PATH = "func.py"
+
+# Percent of container memory to allocate for /dev/shm volume
+MEMORY_VOLUME_SIZE = 0.3
+
+# Job status polling constants
+JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
+JOB_POLL_MAX_DELAY_SECONDS = 1
+
+# Compute pool resource information
+# TODO: Query Snowflake for resource information instead of relying on this hardcoded
+# table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
+COMMON_INSTANCE_FAMILIES = {
+    "CPU_X64_XS": ComputeResources(cpu=1, memory=6),
+    "CPU_X64_S": ComputeResources(cpu=3, memory=13),
+    "CPU_X64_M": ComputeResources(cpu=6, memory=28),
+    "CPU_X64_L": ComputeResources(cpu=28, memory=116),
+    "HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
+}
+AWS_INSTANCE_FAMILIES = {
+    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=240),
+    "HIGHMEM_X64_L": ComputeResources(cpu=124, memory=984),
+    "GPU_NV_S": ComputeResources(cpu=6, memory=27, gpu=1, gpu_type="A10G"),
+    "GPU_NV_M": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="A10G"),
+    "GPU_NV_L": ComputeResources(cpu=92, memory=1112, gpu=8, gpu_type="A100"),
+}
+AZURE_INSTANCE_FAMILIES = {
+    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
+    "HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
+    "GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
+    "GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
+    "GPU_NV_2M": ComputeResources(cpu=68, memory=858, gpu=2, gpu_type="A10"),
+    "GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
+    "GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
+}
+CLOUD_INSTANCE_FAMILIES = {
+    SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
+    SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
+}
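These tables let a spec builder map an instance family name to concrete resources, with cloud-specific families extending the common set. A minimal lookup sketch (the merge below is illustrative; the real resolution happens in spec_utils.py, which this page lists but does not expand):

    from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
    from snowflake.ml.jobs._utils import constants, types

    def resolve_resources(cloud: SnowflakeCloudType, instance_family: str) -> types.ComputeResources:
        # Cloud-specific entries take precedence over the common table.
        table = {**constants.COMMON_INSTANCE_FAMILIES, **constants.CLOUD_INSTANCE_FAMILIES[cloud]}
        return table[instance_family]

    print(resolve_resources(SnowflakeCloudType.AWS, "GPU_NV_S"))
    # -> ComputeResources(cpu=6, memory=27, gpu=1, gpu_type='A10G')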
snowflake/ml/jobs/_utils/payload_utils.py
ADDED
@@ -0,0 +1,352 @@
+import inspect
+import io
+import sys
+import textwrap
+from pathlib import Path, PurePath
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Type,
+    Union,
+    cast,
+    get_args,
+    get_origin,
+)
+
+import cloudpickle as cp
+
+from snowflake import snowpark
+from snowflake.ml.jobs._utils import constants, types
+from snowflake.snowpark._internal import code_generation
+
+_SUPPORTED_ARG_TYPES = {str, int, float}
+_STARTUP_SCRIPT_PATH = PurePath("startup.sh")
+_STARTUP_SCRIPT_CODE = textwrap.dedent(
+    f"""
+    #!/bin/bash
+
+    ##### Perform common set up steps #####
+    set -e # exit if a command fails
+
+    echo "Creating log directories..."
+    mkdir -p /var/log/managedservices/user/mlrs
+    mkdir -p /var/log/managedservices/system/mlrs
+    mkdir -p /var/log/managedservices/system/ray
+
+    echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
+    echo "" >> /etc/cron.d/ray_copy_cron
+    chmod 744 /etc/cron.d/ray_copy_cron
+
+    service cron start
+
+    mkdir -p /tmp/prometheus-multi-dir
+
+    # Change directory to user payload directory
+    if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
+        cd ${constants.PAYLOAD_DIR_ENV_VAR}
+    fi
+
+    ##### Set up Python environment #####
+    export PYTHONPATH=/opt/env/site-packages/
+    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
+    if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
+        # TODO: Prevent collisions with MLRS packages using virtualenvs
+        echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
+        pip install -r $MLRS_REQUIREMENTS_FILE
+    fi
+
+    MLRS_CONDA_ENV_FILE=${{MLRS_CONDA_ENV_FILE:-"environment.yml"}}
+    if [ -f "${{MLRS_CONDA_ENV_FILE}}" ]; then
+        # TODO: Handle conda environment
+        echo "Custom conda environments not currently supported"
+        exit 1
+    fi
+    ##### End Python environment setup #####
+
+    ##### Ray configuration #####
+    shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
+
+    # Configure IP address and logging directory
+    eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
+    log_dir="/tmp/ray"
+
+    # Check if eth0Ip is empty and set default if necessary
+    if [ -z "$eth0Ip" ]; then
+        # This should never happen, but just in case ethOIp is not set, we should default to localhost
+        eth0Ip="127.0.0.1"
+    fi
+
+    # Common parameters for both head and worker nodes
+    common_params=(
+        "--node-ip-address=$eth0Ip"
+        "--object-manager-port=${{RAY_OBJECT_MANAGER_PORT:-12011}}"
+        "--node-manager-port=${{RAY_NODE_MANAGER_PORT:-12012}}"
+        "--runtime-env-agent-port=${{RAY_RUNTIME_ENV_AGENT_PORT:-12013}}"
+        "--dashboard-agent-grpc-port=${{RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}}"
+        "--dashboard-agent-listen-port=${{RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}}"
+        "--min-worker-port=${{RAY_MIN_WORKER_PORT:-12031}}"
+        "--max-worker-port=${{RAY_MAX_WORKER_PORT:-13000}}"
+        "--metrics-export-port=11502"
+        "--temp-dir=$log_dir"
+        "--disable-usage-stats"
+    )
+
+    # Additional head-specific parameters
+    head_params=(
+        "--head"
+        "--port=${{RAY_HEAD_GCS_PORT:-12001}}"  # Port of Ray (GCS server)
+        "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}"  # Listening port for Ray Client Server
+        "--dashboard-host=${{NODE_IP_ADDRESS}}"  # Host to bind the dashboard server
+        "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}"  # Dashboard head to listen for grpc on
+        "--dashboard-port=${{DASHBOARD_PORT}}"  # Port to bind the dashboard server for local debugging
+        "--resources={{\\"node_tag:head\\":1}}"  # Resource tag for selecting head as coordinator
+    )
+
+    # Start Ray on the head node
+    ray start "${{common_params[@]}}" "${{head_params[@]}}" &
+    ##### End Ray configuration #####
+
+    # TODO: Monitor MLRS and handle process crashes
+    python -m web.ml_runtime_grpc_server &
+
+    # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+
+    # Run user's Python entrypoint
+    echo Running command: python "$@"
+    python "$@"
+    """
+).strip()
+
+
+class JobPayload:
+    def __init__(
+        self,
+        source: Union[str, Path, Callable[..., Any]],
+        entrypoint: Optional[Union[str, Path]] = None,
+        *,
+        pip_requirements: Optional[List[str]] = None,
+    ) -> None:
+        self.source = Path(source) if isinstance(source, str) else source
+        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+        self.pip_requirements = pip_requirements
+
+    def validate(self) -> None:
+        if callable(self.source):
+            # Any entrypoint value is OK for callable payloads (including None aka default)
+            # since we will generate the file from the serialized callable
+            pass
+        elif isinstance(self.source, Path):
+            # Validate self.source and self.entrypoint for files
+            if not self.source.exists():
+                raise FileNotFoundError(f"{self.source} does not exist")
+            if self.entrypoint is None:
+                if self.source.is_file():
+                    self.entrypoint = self.source
+                else:
+                    raise ValueError("entrypoint must be provided when source is a directory")
+            if not self.entrypoint.is_file():
+                # Check if self.entrypoint is a valid relative path
+                self.entrypoint = self.source.joinpath(self.entrypoint)
+                if not self.entrypoint.is_file():
+                    raise FileNotFoundError(f"File {self.entrypoint} does not exist")
+            if not self.entrypoint.is_relative_to(self.source):
+                raise ValueError(f"{self.entrypoint} must be a subpath of {self.source}")
+            if self.entrypoint.suffix != ".py":
+                raise NotImplementedError("Only Python entrypoints are supported currently")
+        else:
+            raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+
+    def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
+        # Validate payload
+        self.validate()
+
+        # Prepare local variables
+        stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
+        source = self.source
+        entrypoint = self.entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH)
+
+        # Create stage if necessary
+        stage_name = stage_path.parts[0]
+        session.sql(
+            f"create stage if not exists {stage_name.lstrip('@')}"
+            " encryption = ( type = 'SNOWFLAKE_SSE' )"
+            " comment = 'Created by snowflake.ml.jobs Python API'"
+        ).collect()
+
+        # Upload payload to stage
+        if not isinstance(source, Path):
+            source_code = generate_python_code(source, source_code_display=True)
+            _ = session.file.put_stream(
+                io.BytesIO(source_code.encode()),
+                stage_location=stage_path.joinpath(entrypoint).as_posix(),
+                auto_compress=False,
+                overwrite=True,
+            )
+            source = entrypoint.parent
+        elif source.is_dir():
+            # Manually traverse the directory and upload each file, since Snowflake PUT
+            # can't handle directories. Reduce the number of PUT operations by using
+            # wildcard patterns to batch upload files with the same extension.
+            for path in {
+                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p for p in source.resolve().rglob("*") if p.is_file()
+            }:
+                session.file.put(
+                    str(path),
+                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                    overwrite=True,
+                    auto_compress=False,
+                )
+        else:
+            session.file.put(
+                str(source.resolve()),
+                stage_path.as_posix(),
+                overwrite=True,
+                auto_compress=False,
+            )
+            source = source.parent
+
+        # Upload requirements
+        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
+        if self.pip_requirements:
+            # Upload requirements.txt to stage
+            session.file.put_stream(
+                io.BytesIO("\n".join(self.pip_requirements).encode()),
+                stage_location=stage_path.joinpath("requirements.txt").as_posix(),
+                auto_compress=False,
+                overwrite=True,
+            )
+
+        # Upload startup script
+        # TODO: Make sure payload does not include file with same name
+        session.file.put_stream(
+            io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
+            stage_location=stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
+            auto_compress=False,
+            overwrite=False,  # FIXME
+        )
+
+        return types.UploadedPayload(
+            stage_path=stage_path,
+            entrypoint=[
+                "bash",
+                _STARTUP_SCRIPT_PATH,
+                entrypoint.relative_to(source),
+            ],
+        )
+
+
+def get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+    # Unwrap Optional type annotations
+    param_type = param.annotation
+    if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
+        param_type = next(t for t in get_args(param_type) if t is not type(None))
+
+    # Return None for empty type annotations
+    if param_type == inspect.Parameter.empty:
+        return None
+    return cast(Type[object], param_type)
+
+
+def validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+    # Validate param_type is a supported type
+    if param_type not in _SUPPORTED_ARG_TYPES:
+        raise ValueError(
+            f"Unsupported argument type {param_type} for '{param_name}'."
+            f" Supported types: {', '.join(t.__name__ for t in _SUPPORTED_ARG_TYPES)}"
+        )
+
+
+def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+    signature = inspect.signature(func)
+    if any(
+        p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+        for p in signature.parameters.values()
+    ):
+        raise NotImplementedError("Function must not have unpacking arguments (* or **)")
+
+    # Mirrored from Snowpark generate_python_code() function
+    # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
+    try:
+        source_code_comment = (
+            code_generation.generate_source_code(func) if source_code_display else ""  # type: ignore[arg-type]
+        )
+    except Exception as exc:
+        error_msg = f"Source code comment could not be generated for {func} due to error {exc}."
+        source_code_comment = code_generation.comment_source_code(error_msg)
+
+    func_name = "func"
+    func_code = f"""
+{source_code_comment}
+
+import pickle
+{func_name} = pickle.loads(bytes.fromhex('{cp.dumps(func).hex()}'))
+"""
+
+    # Generate argparse logic for argument handling (type coercion, default values, etc)
+    argparse_code = ["import argparse", "", "parser = argparse.ArgumentParser()"]
+    argparse_postproc = []
+    for name, param in signature.parameters.items():
+        opts = {}
+
+        param_type = get_parameter_type(param)
+        if param_type is not None:
+            validate_parameter_type(param_type, name)
+            opts["type"] = param_type.__name__
+
+        if param.default != inspect.Parameter.empty:
+            opts["default"] = f"'{param.default}'" if isinstance(param.default, str) else param.default
+
+        if param.kind == inspect.Parameter.KEYWORD_ONLY:
+            # Keyword argument
+            argparse_code.append(
+                f"parser.add_argument('--{name}', required={'default' not in opts},"
+                f" {', '.join(f'{k}={v}' for k, v in opts.items())})"
+            )
+        else:
+            # Positional argument. Use `argparse.add_mutually_exclusive_group()`
+            # to allow passing positional args by name as well
+            group_name = f"{name}_group"
+            argparse_code.append(
+                f"{group_name} = parser.add_mutually_exclusive_group(required={'default' not in opts})"
+            )
+            argparse_code.append(
+                f"{group_name}.add_argument('pos-{name}', metavar='{name}', nargs='?',"
+                f" {', '.join(f'{k}={v}' for k, v in opts.items() if k != 'default')})"
+            )
+            argparse_code.append(
+                f"{group_name}.add_argument('--{name}', {', '.join(f'{k}={v}' for k, v in opts.items())})"
+            )
+        argparse_code.append("")  # Add newline for readability
+        argparse_postproc.append(
+            f"args.{name} = {name} if ({name} := args.__dict__.pop('pos-{name}')) is not None else args.{name}"
+        )
+    argparse_code.append("args = parser.parse_args()")
+    param_code = "\n".join(argparse_code + argparse_postproc)
+
+    return f"""
+### Version guard to check compatibility across Python versions ###
+import sys
+import warnings
+
+if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
+    warnings.warn(
+        "Python version mismatch: job was created using"
+        " python{sys.version_info.major}.{sys.version_info.minor}"
+        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
+        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
+        " This will be fixed in a future release; for now, please use Python version"
+        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
+        RuntimeWarning,
+        stacklevel=0,
+    )
+### End version guard ###
+
+{func_code.strip()}
+
+if __name__ == '__main__':
+{textwrap.indent(param_code, '    ')}
+
+    {func_name}(**vars(args))
+"""
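To see what the callable path uploads, generate_python_code can be called directly. A minimal sketch using a stand-in function; per the code above, the emitted script contains the version guard, the pickled function, and the argparse shim, so arguments can be passed either positionally or by name:

    from snowflake.ml.jobs._utils.payload_utils import generate_python_code

    def train(data_path: str, epochs: int = 10) -> None:
        print(f"Training on {data_path} for {epochs} epochs")

    script = generate_python_code(train, source_code_display=True)
    print(script)
    # The generated entrypoint (written as func.py by default) then accepts either form:
    #   python func.py @my_stage/data --epochs 5
    #   python func.py --data_path @my_stage/data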