snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,24 +1,32 @@
|
|
1
|
-
from snowflake.cortex._classify_text import ClassifyText
|
2
|
-
from snowflake.cortex._complete import Complete, CompleteOptions
|
3
|
-
from snowflake.cortex._embed_text_768 import EmbedText768
|
4
|
-
from snowflake.cortex._embed_text_1024 import EmbedText1024
|
5
|
-
from snowflake.cortex._extract_answer import ExtractAnswer
|
1
|
+
from snowflake.cortex._classify_text import ClassifyText, classify_text
|
2
|
+
from snowflake.cortex._complete import Complete, CompleteOptions, complete
|
3
|
+
from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
|
4
|
+
from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
|
5
|
+
from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
|
6
6
|
from snowflake.cortex._finetune import Finetune, FinetuneJob, FinetuneStatus
|
7
|
-
from snowflake.cortex._sentiment import Sentiment
|
8
|
-
from snowflake.cortex._summarize import Summarize
|
9
|
-
from snowflake.cortex._translate import Translate
|
7
|
+
from snowflake.cortex._sentiment import Sentiment, sentiment
|
8
|
+
from snowflake.cortex._summarize import Summarize, summarize
|
9
|
+
from snowflake.cortex._translate import Translate, translate
|
10
10
|
|
11
11
|
__all__ = [
|
12
12
|
"ClassifyText",
|
13
|
+
"classify_text",
|
13
14
|
"Complete",
|
15
|
+
"complete",
|
14
16
|
"CompleteOptions",
|
15
17
|
"EmbedText768",
|
18
|
+
"embed_text_768",
|
16
19
|
"EmbedText1024",
|
20
|
+
"embed_text_1024",
|
17
21
|
"ExtractAnswer",
|
22
|
+
"extract_answer",
|
18
23
|
"Finetune",
|
19
24
|
"FinetuneJob",
|
20
25
|
"FinetuneStatus",
|
21
26
|
"Sentiment",
|
27
|
+
"sentiment",
|
22
28
|
"Summarize",
|
29
|
+
"summarize",
|
23
30
|
"Translate",
|
31
|
+
"translate",
|
24
32
|
]
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import List, Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,7 +10,7 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def classify_text(
|
12
14
|
str_input: Union[str, snowpark.Column],
|
13
15
|
categories: Union[List[str], snowpark.Column],
|
14
16
|
session: Optional[snowpark.Session] = None,
|
@@ -34,3 +36,12 @@ def _classify_text_impl(
|
|
34
36
|
session: Optional[snowpark.Session] = None,
|
35
37
|
) -> Union[str, snowpark.Column]:
|
36
38
|
return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories))
|
39
|
+
|
40
|
+
|
41
|
+
ClassifyText = deprecated(
|
42
|
+
"ClassifyText() is deprecated and will be removed in a future release. Please use classify_text() instead."
|
43
|
+
)(
|
44
|
+
telemetry.send_api_usage_telemetry(
|
45
|
+
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
46
|
+
)(classify_text)
|
47
|
+
)
|
snowflake/cortex/_complete.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Uni
|
|
6
6
|
from urllib.parse import urlunparse
|
7
7
|
|
8
8
|
import requests
|
9
|
-
from typing_extensions import NotRequired
|
9
|
+
from typing_extensions import NotRequired, deprecated
|
10
10
|
|
11
11
|
from snowflake import snowpark
|
12
12
|
from snowflake.cortex._sse_client import SSEClient
|
@@ -49,6 +49,10 @@ class CompleteOptions(TypedDict):
|
|
49
49
|
generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
|
50
50
|
that the model outputs, while temperature influences which tokens are chosen at each step. """
|
51
51
|
|
52
|
+
guardrails: NotRequired[bool]
|
53
|
+
""" A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
|
54
|
+
from the language model. """
|
55
|
+
|
52
56
|
|
53
57
|
class ResponseParseException(Exception):
|
54
58
|
"""This exception is raised when the server response cannot be parsed."""
|
@@ -56,6 +60,15 @@ class ResponseParseException(Exception):
|
|
56
60
|
pass
|
57
61
|
|
58
62
|
|
63
|
+
class GuardrailsOptions(TypedDict):
|
64
|
+
enabled: bool
|
65
|
+
"""A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
|
66
|
+
from the language model."""
|
67
|
+
|
68
|
+
response_when_unsafe: str
|
69
|
+
"""The response to return when the language model generates unsafe or harmful content."""
|
70
|
+
|
71
|
+
|
59
72
|
_MAX_RETRY_SECONDS = 30
|
60
73
|
|
61
74
|
|
@@ -117,6 +130,12 @@ def _make_request_body(
|
|
117
130
|
data["temperature"] = options["temperature"]
|
118
131
|
if "top_p" in options:
|
119
132
|
data["top_p"] = options["top_p"]
|
133
|
+
if "guardrails" in options and options["guardrails"]:
|
134
|
+
guardrails_options: GuardrailsOptions = {
|
135
|
+
"enabled": True,
|
136
|
+
"response_when_unsafe": "Response filtered by Cortex Guard",
|
137
|
+
}
|
138
|
+
data["guardrails"] = guardrails_options
|
120
139
|
return data
|
121
140
|
|
122
141
|
|
@@ -127,8 +146,26 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
|
|
127
146
|
response.status_code = int(raw_resp["status"])
|
128
147
|
response.headers = raw_resp["headers"]
|
129
148
|
|
149
|
+
request_id = None
|
150
|
+
for key, value in raw_resp["headers"].items():
|
151
|
+
# Note: there is some whitespace in the headers making it not possible
|
152
|
+
# to directly index the header reliably.
|
153
|
+
if key.strip().lower() == "x-snowflake-request-id":
|
154
|
+
request_id = value
|
155
|
+
break
|
156
|
+
|
130
157
|
data = raw_resp["content"]
|
131
|
-
|
158
|
+
try:
|
159
|
+
data = json.loads(data)
|
160
|
+
except json.JSONDecodeError:
|
161
|
+
raise ValueError(f"Request failed (request id: {request_id})")
|
162
|
+
|
163
|
+
if response.status_code < 200 or response.status_code >= 300:
|
164
|
+
if "message" not in data:
|
165
|
+
raise ValueError(f"Request failed (request id: {request_id})")
|
166
|
+
message = data["message"]
|
167
|
+
raise ValueError(f"Request failed: {message} (request id: {request_id})")
|
168
|
+
|
132
169
|
# Convert the dictionary to a string format that resembles the SSE event format
|
133
170
|
# For example, if the dict is {'event': 'message', 'data': 'your data'}, it should be formatted like this:
|
134
171
|
sse_format_data = ""
|
@@ -144,6 +181,7 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
|
|
144
181
|
|
145
182
|
@retry
|
146
183
|
def _call_complete_xp(
|
184
|
+
snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
|
147
185
|
model: str,
|
148
186
|
prompt: Union[str, List[ConversationMessage]],
|
149
187
|
options: Optional[CompleteOptions] = None,
|
@@ -151,9 +189,8 @@ def _call_complete_xp(
|
|
151
189
|
) -> requests.Response:
|
152
190
|
headers = _make_common_request_headers()
|
153
191
|
body = _make_request_body(model, prompt, options)
|
154
|
-
|
155
|
-
|
156
|
-
raw_resp = _snowflake.send_snow_api_request("POST", _REST_COMPLETE_URL, {}, headers, body, {}, deadline)
|
192
|
+
assert snow_api_xp_request_handler is not None
|
193
|
+
raw_resp = snow_api_xp_request_handler("POST", _REST_COMPLETE_URL, {}, headers, body, {}, deadline)
|
157
194
|
return _xp_dict_to_response(raw_resp)
|
158
195
|
|
159
196
|
|
@@ -218,17 +255,26 @@ def _complete_call_sql_function_snowpark(
|
|
218
255
|
|
219
256
|
|
220
257
|
def _complete_non_streaming_immediate(
|
258
|
+
snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
|
221
259
|
model: str,
|
222
260
|
prompt: Union[str, List[ConversationMessage]],
|
223
261
|
options: Optional[CompleteOptions],
|
224
262
|
session: Optional[snowpark.Session] = None,
|
225
263
|
deadline: Optional[float] = None,
|
226
264
|
) -> str:
|
227
|
-
response = _complete_rest(
|
265
|
+
response = _complete_rest(
|
266
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
267
|
+
model=model,
|
268
|
+
prompt=prompt,
|
269
|
+
options=options,
|
270
|
+
session=session,
|
271
|
+
deadline=deadline,
|
272
|
+
)
|
228
273
|
return "".join(response)
|
229
274
|
|
230
275
|
|
231
276
|
def _complete_non_streaming_impl(
|
277
|
+
snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
|
232
278
|
function: str,
|
233
279
|
model: Union[str, snowpark.Column],
|
234
280
|
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
@@ -246,19 +292,31 @@ def _complete_non_streaming_impl(
|
|
246
292
|
if isinstance(options, snowpark.Column):
|
247
293
|
raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
|
248
294
|
return _complete_non_streaming_immediate(
|
249
|
-
|
295
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
296
|
+
model=model,
|
297
|
+
prompt=prompt,
|
298
|
+
options=options,
|
299
|
+
session=session,
|
300
|
+
deadline=deadline,
|
250
301
|
)
|
251
302
|
|
252
303
|
|
253
304
|
def _complete_rest(
|
305
|
+
snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
|
254
306
|
model: str,
|
255
307
|
prompt: Union[str, List[ConversationMessage]],
|
256
308
|
options: Optional[CompleteOptions] = None,
|
257
309
|
session: Optional[snowpark.Session] = None,
|
258
310
|
deadline: Optional[float] = None,
|
259
311
|
) -> Iterator[str]:
|
260
|
-
if
|
261
|
-
response = _call_complete_xp(
|
312
|
+
if snow_api_xp_request_handler is not None:
|
313
|
+
response = _call_complete_xp(
|
314
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
315
|
+
model=model,
|
316
|
+
prompt=prompt,
|
317
|
+
options=options,
|
318
|
+
deadline=deadline,
|
319
|
+
)
|
262
320
|
else:
|
263
321
|
response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
|
264
322
|
assert response.status_code >= 200 and response.status_code < 300
|
@@ -268,10 +326,11 @@ def _complete_rest(
|
|
268
326
|
def _complete_impl(
|
269
327
|
model: Union[str, snowpark.Column],
|
270
328
|
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
329
|
+
snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]] = None,
|
330
|
+
function: str = "snowflake.cortex.complete",
|
271
331
|
options: Optional[CompleteOptions] = None,
|
272
332
|
session: Optional[snowpark.Session] = None,
|
273
333
|
stream: bool = False,
|
274
|
-
function: str = "snowflake.cortex.complete",
|
275
334
|
timeout: Optional[float] = None,
|
276
335
|
deadline: Optional[float] = None,
|
277
336
|
) -> Union[str, Iterator[str], snowpark.Column]:
|
@@ -284,14 +343,29 @@ def _complete_impl(
|
|
284
343
|
raise ValueError("in REST mode, 'model' must be a string")
|
285
344
|
if not isinstance(prompt, str) and not isinstance(prompt, List):
|
286
345
|
raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
|
287
|
-
return _complete_rest(
|
288
|
-
|
346
|
+
return _complete_rest(
|
347
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
348
|
+
model=model,
|
349
|
+
prompt=prompt,
|
350
|
+
options=options,
|
351
|
+
session=session,
|
352
|
+
deadline=deadline,
|
353
|
+
)
|
354
|
+
return _complete_non_streaming_impl(
|
355
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
356
|
+
function=function,
|
357
|
+
model=model,
|
358
|
+
prompt=prompt,
|
359
|
+
options=options,
|
360
|
+
session=session,
|
361
|
+
deadline=deadline,
|
362
|
+
)
|
289
363
|
|
290
364
|
|
291
365
|
@telemetry.send_api_usage_telemetry(
|
292
366
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
293
367
|
)
|
294
|
-
def
|
368
|
+
def complete(
|
295
369
|
model: Union[str, snowpark.Column],
|
296
370
|
prompt: Union[str, List[ConversationMessage], snowpark.Column],
|
297
371
|
*,
|
@@ -319,10 +393,19 @@ def Complete(
|
|
319
393
|
Returns:
|
320
394
|
A column of string responses.
|
321
395
|
"""
|
396
|
+
|
397
|
+
# Set the XP snow api function, if available.
|
398
|
+
snow_api_xp_request_handler = None
|
399
|
+
if is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
400
|
+
import _snowflake
|
401
|
+
|
402
|
+
snow_api_xp_request_handler = _snowflake.send_snow_api_request
|
403
|
+
|
322
404
|
try:
|
323
405
|
return _complete_impl(
|
324
406
|
model,
|
325
407
|
prompt,
|
408
|
+
snow_api_xp_request_handler=snow_api_xp_request_handler,
|
326
409
|
options=options,
|
327
410
|
session=session,
|
328
411
|
stream=stream,
|
@@ -331,3 +414,8 @@ def Complete(
|
|
331
414
|
)
|
332
415
|
except ValueError as err:
|
333
416
|
raise err
|
417
|
+
|
418
|
+
|
419
|
+
Complete = deprecated("Complete() is deprecated and will be removed in a future release. Use complete() instead")(
|
420
|
+
telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(complete)
|
421
|
+
)
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import List, Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def embed_text_1024(
|
12
14
|
model: Union[str, snowpark.Column],
|
13
15
|
text: Union[str, snowpark.Column],
|
14
16
|
session: Optional[snowpark.Session] = None,
|
15
17
|
) -> Union[List[float], snowpark.Column]:
|
16
|
-
"""
|
18
|
+
"""Calls into the LLM inference service to embed the text.
|
17
19
|
|
18
20
|
Args:
|
19
21
|
model: A Column of strings representing the model to use for embedding. The value
|
@@ -35,3 +37,8 @@ def _embed_text_1024_impl(
|
|
35
37
|
session: Optional[snowpark.Session] = None,
|
36
38
|
) -> Union[List[float], snowpark.Column]:
|
37
39
|
return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
|
40
|
+
|
41
|
+
|
42
|
+
EmbedText1024 = deprecated(
|
43
|
+
"EmbedText1024() is deprecated and will be removed in a future release. Use embed_text_1024() instead"
|
44
|
+
)(telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(embed_text_1024))
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import List, Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def embed_text_768(
|
12
14
|
model: Union[str, snowpark.Column],
|
13
15
|
text: Union[str, snowpark.Column],
|
14
16
|
session: Optional[snowpark.Session] = None,
|
15
17
|
) -> Union[List[float], snowpark.Column]:
|
16
|
-
"""
|
18
|
+
"""Calls into the LLM inference service to embed the text.
|
17
19
|
|
18
20
|
Args:
|
19
21
|
model: A Column of strings representing the model to use for embedding. The value
|
@@ -35,3 +37,8 @@ def _embed_text_768_impl(
|
|
35
37
|
session: Optional[snowpark.Session] = None,
|
36
38
|
) -> Union[List[float], snowpark.Column]:
|
37
39
|
return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
|
40
|
+
|
41
|
+
|
42
|
+
EmbedText768 = deprecated(
|
43
|
+
"EmbedText768() is deprecated and will be removed in a future release. Use embed_text_768() instead"
|
44
|
+
)(telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(embed_text_768))
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def extract_answer(
|
12
14
|
from_text: Union[str, snowpark.Column],
|
13
15
|
question: Union[str, snowpark.Column],
|
14
16
|
session: Optional[snowpark.Session] = None,
|
15
17
|
) -> Union[str, snowpark.Column]:
|
16
|
-
"""
|
18
|
+
"""Calls into the LLM inference service to extract an answer from within specified text.
|
17
19
|
|
18
20
|
Args:
|
19
21
|
from_text: A Column of strings representing input text.
|
@@ -34,3 +36,8 @@ def _extract_answer_impl(
|
|
34
36
|
session: Optional[snowpark.Session] = None,
|
35
37
|
) -> Union[str, snowpark.Column]:
|
36
38
|
return cast(Union[str, snowpark.Column], call_sql_function(function, session, from_text, question))
|
39
|
+
|
40
|
+
|
41
|
+
ExtractAnswer = deprecated(
|
42
|
+
"ExtractAnswer() is deprecated and will be removed in a future release. Use extract_answer() instead"
|
43
|
+
)(telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(extract_answer))
|
snowflake/cortex/_sentiment.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,10 +10,10 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def sentiment(
|
12
14
|
text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
|
13
15
|
) -> Union[float, snowpark.Column]:
|
14
|
-
"""
|
16
|
+
"""Calls into the LLM inference service to perform sentiment analysis on the input text.
|
15
17
|
|
16
18
|
Args:
|
17
19
|
text: A Column of text strings to send to the LLM.
|
@@ -31,3 +33,8 @@ def _sentiment_impl(
|
|
31
33
|
if isinstance(output, snowpark.Column):
|
32
34
|
return output
|
33
35
|
return float(cast(str, output))
|
36
|
+
|
37
|
+
|
38
|
+
Sentiment = deprecated("Sentiment() is deprecated and will be removed in a future release. Use sentiment() instead")(
|
39
|
+
telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(sentiment)
|
40
|
+
)
|
snowflake/cortex/_summarize.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,11 +10,11 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def summarize(
|
12
14
|
text: Union[str, snowpark.Column],
|
13
15
|
session: Optional[snowpark.Session] = None,
|
14
16
|
) -> Union[str, snowpark.Column]:
|
15
|
-
"""
|
17
|
+
"""Calls into the LLM inference service to summarize the input text.
|
16
18
|
|
17
19
|
Args:
|
18
20
|
text: A Column of strings to summarize.
|
@@ -31,3 +33,8 @@ def _summarize_impl(
|
|
31
33
|
session: Optional[snowpark.Session] = None,
|
32
34
|
) -> Union[str, snowpark.Column]:
|
33
35
|
return cast(Union[str, snowpark.Column], call_sql_function(function, session, text))
|
36
|
+
|
37
|
+
|
38
|
+
Summarize = deprecated("Summarize() is deprecated and will be removed in a future release. Use summarize() instead")(
|
39
|
+
telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(summarize)
|
40
|
+
)
|
snowflake/cortex/_translate.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
from typing import Optional, Union, cast
|
2
2
|
|
3
|
+
from typing_extensions import deprecated
|
4
|
+
|
3
5
|
from snowflake import snowpark
|
4
6
|
from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
|
5
7
|
from snowflake.ml._internal import telemetry
|
@@ -8,13 +10,13 @@ from snowflake.ml._internal import telemetry
|
|
8
10
|
@telemetry.send_api_usage_telemetry(
|
9
11
|
project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
|
10
12
|
)
|
11
|
-
def
|
13
|
+
def translate(
|
12
14
|
text: Union[str, snowpark.Column],
|
13
15
|
from_language: Union[str, snowpark.Column],
|
14
16
|
to_language: Union[str, snowpark.Column],
|
15
17
|
session: Optional[snowpark.Session] = None,
|
16
18
|
) -> Union[str, snowpark.Column]:
|
17
|
-
"""
|
19
|
+
"""Calls into the LLM inference service to perform translation.
|
18
20
|
|
19
21
|
Args:
|
20
22
|
text: A Column of strings to translate.
|
@@ -37,3 +39,8 @@ def _translate_impl(
|
|
37
39
|
session: Optional[snowpark.Session] = None,
|
38
40
|
) -> Union[str, snowpark.Column]:
|
39
41
|
return cast(Union[str, snowpark.Column], call_sql_function(function, session, text, from_language, to_language))
|
42
|
+
|
43
|
+
|
44
|
+
Translate = deprecated("Translate() is deprecated and will be removed in a future release. Use translate() instead")(
|
45
|
+
telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(translate)
|
46
|
+
)
|
@@ -15,7 +15,6 @@ import snowflake.connector
|
|
15
15
|
from snowflake.ml._internal import env as snowml_env
|
16
16
|
from snowflake.ml._internal.utils import query_result_checker
|
17
17
|
from snowflake.snowpark import context, exceptions, session
|
18
|
-
from snowflake.snowpark._internal import utils as snowpark_utils
|
19
18
|
|
20
19
|
|
21
20
|
class CONDA_OS(Enum):
|
@@ -344,55 +343,6 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
|
|
344
343
|
return new_req
|
345
344
|
|
346
345
|
|
347
|
-
def get_matched_package_versions_in_snowflake_conda_channel(
|
348
|
-
req: requirements.Requirement,
|
349
|
-
python_version: str = snowml_env.PYTHON_VERSION,
|
350
|
-
conda_os: CONDA_OS = CONDA_OS.LINUX_64,
|
351
|
-
) -> List[version.Version]:
|
352
|
-
"""Search the snowflake anaconda channel for packages that matches the specifier. Note that this will be the
|
353
|
-
source of truth for checking whether a package indeed exists in Snowflake conda channel.
|
354
|
-
|
355
|
-
Given that a package comes in different architectures, we only check for the Linux x86_64 architecture and assume
|
356
|
-
the package exists in other architectures. If such an assumption does not hold true for a certain package, the
|
357
|
-
caller should specify the architecture to search for.
|
358
|
-
|
359
|
-
Args:
|
360
|
-
req: Requirement specifier.
|
361
|
-
python_version: A string of python version where model is run.
|
362
|
-
conda_os: Specified platform to search availability of the package.
|
363
|
-
|
364
|
-
Returns:
|
365
|
-
List of package versions that meet the requirement specifier.
|
366
|
-
"""
|
367
|
-
# Move the retryable_http import here as when UDF import this file, it won't have the "requests" dependency.
|
368
|
-
from snowflake.ml._internal.utils import retryable_http
|
369
|
-
|
370
|
-
assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
|
371
|
-
|
372
|
-
url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
|
373
|
-
|
374
|
-
if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
|
375
|
-
try:
|
376
|
-
http_client = retryable_http.get_http_client()
|
377
|
-
parsed_python_version = version.Version(python_version)
|
378
|
-
python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}"
|
379
|
-
repodata = http_client.get(url).json()
|
380
|
-
assert isinstance(repodata, dict)
|
381
|
-
packages_info = repodata["packages"]
|
382
|
-
assert isinstance(packages_info, dict)
|
383
|
-
version_list = [
|
384
|
-
version.parse(package_info["version"])
|
385
|
-
for package_info in packages_info.values()
|
386
|
-
if package_info["name"] == req.name and python_version_build_str in package_info["build"]
|
387
|
-
]
|
388
|
-
_SNOWFLAKE_CONDA_PACKAGE_CACHE[req.name] = version_list
|
389
|
-
except Exception:
|
390
|
-
pass
|
391
|
-
|
392
|
-
matched_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
|
393
|
-
return matched_versions
|
394
|
-
|
395
|
-
|
396
346
|
def get_matched_package_versions_in_information_schema_with_active_session(
|
397
347
|
reqs: List[requirements.Requirement], python_version: str
|
398
348
|
) -> Dict[str, List[version.Version]]:
|
@@ -404,7 +354,10 @@ def get_matched_package_versions_in_information_schema_with_active_session(
|
|
404
354
|
|
405
355
|
|
406
356
|
def get_matched_package_versions_in_information_schema(
|
407
|
-
session: session.Session,
|
357
|
+
session: session.Session,
|
358
|
+
reqs: List[requirements.Requirement],
|
359
|
+
python_version: str,
|
360
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
408
361
|
) -> Dict[str, List[version.Version]]:
|
409
362
|
"""Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
|
410
363
|
Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
|
@@ -414,6 +367,7 @@ def get_matched_package_versions_in_information_schema(
|
|
414
367
|
session: Snowflake connection session.
|
415
368
|
reqs: List of requirement specifiers.
|
416
369
|
python_version: A string of python version where model is run.
|
370
|
+
statement_params: Optional statement parameters.
|
417
371
|
|
418
372
|
Returns:
|
419
373
|
A Dict, whose key is the package name, and value is a list of versions match the requirements.
|
@@ -451,8 +405,9 @@ def get_matched_package_versions_in_information_schema(
|
|
451
405
|
query_result_checker.SqlResultValidator(
|
452
406
|
session=session,
|
453
407
|
query=sql,
|
408
|
+
statement_params=statement_params,
|
454
409
|
)
|
455
|
-
.has_column("VERSION")
|
410
|
+
.has_column("VERSION", allow_empty=True)
|
456
411
|
.has_dimensions(expected_rows=None, expected_cols=2)
|
457
412
|
.validate()
|
458
413
|
)
|
@@ -0,0 +1,87 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, Optional
|
3
|
+
|
4
|
+
from absl import logging
|
5
|
+
|
6
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
7
|
+
from snowflake.ml._internal.utils import query_result_checker
|
8
|
+
from snowflake.snowpark import (
|
9
|
+
exceptions as snowpark_exceptions,
|
10
|
+
session as snowpark_session,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
class PlatformCapabilities:
|
15
|
+
"""Class that retrieves platform feature values for the currently running server.
|
16
|
+
|
17
|
+
Example usage:
|
18
|
+
```
|
19
|
+
pc = PlatformCapabilities.get_instance(session)
|
20
|
+
if pc.is_nested_function_enabled():
|
21
|
+
# Nested functions are enabled.
|
22
|
+
print("Nested functions are enabled.")
|
23
|
+
else:
|
24
|
+
# Nested functions are disabled.
|
25
|
+
print("Nested functions are disabled or not supported.")
|
26
|
+
```
|
27
|
+
"""
|
28
|
+
|
29
|
+
_instance: Optional["PlatformCapabilities"] = None
|
30
|
+
|
31
|
+
@classmethod
|
32
|
+
def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
|
33
|
+
if not cls._instance:
|
34
|
+
cls._instance = cls(session)
|
35
|
+
return cls._instance
|
36
|
+
|
37
|
+
def is_nested_function_enabled(self) -> bool:
|
38
|
+
return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
|
39
|
+
|
40
|
+
@staticmethod
|
41
|
+
def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
|
42
|
+
try:
|
43
|
+
result = (
|
44
|
+
query_result_checker.SqlResultValidator(
|
45
|
+
session=session,
|
46
|
+
query="SELECT SYSTEM$ML_PLATFORM_CAPABILITIES() AS FEATURES;",
|
47
|
+
)
|
48
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
49
|
+
.has_column("FEATURES")
|
50
|
+
.validate()[0]
|
51
|
+
)
|
52
|
+
if "FEATURES" in result:
|
53
|
+
capabilities_json: str = result["FEATURES"]
|
54
|
+
try:
|
55
|
+
parsed_json = json.loads(capabilities_json)
|
56
|
+
assert isinstance(parsed_json, dict), f"Expected JSON object, got {type(parsed_json)}"
|
57
|
+
return parsed_json
|
58
|
+
except json.JSONDecodeError as e:
|
59
|
+
message = f"""Unable to parse JSON from: "{capabilities_json}"; Error="{e}"."""
|
60
|
+
raise exceptions.SnowflakeMLException(
|
61
|
+
error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
|
62
|
+
)
|
63
|
+
except snowpark_exceptions.SnowparkSQLException as e:
|
64
|
+
logging.debug(f"Failed to retrieve platform capabilities: {e}")
|
65
|
+
# This can happen is server side is older than 9.2. That is fine.
|
66
|
+
return {}
|
67
|
+
|
68
|
+
def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
|
69
|
+
if not session:
|
70
|
+
session = next(iter(snowpark_session._get_active_sessions()))
|
71
|
+
assert session, "Missing active session object"
|
72
|
+
self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
|
73
|
+
|
74
|
+
def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
|
75
|
+
value = self.features.get(feature_name, default_value)
|
76
|
+
if isinstance(value, bool):
|
77
|
+
return value
|
78
|
+
if isinstance(value, int) and value in [0, 1]:
|
79
|
+
return value == 1
|
80
|
+
if isinstance(value, str):
|
81
|
+
if value.lower() in ["true", "1"]:
|
82
|
+
return True
|
83
|
+
elif value.lower() in ["false", "0"]:
|
84
|
+
return False
|
85
|
+
else:
|
86
|
+
raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
|
87
|
+
raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
|