snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +101 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/platform_capabilities.py +87 -0
  12. snowflake/ml/_internal/utils/identifier.py +4 -2
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/dataset/dataset.py +0 -1
  19. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  20. snowflake/ml/fileset/fileset.py +24 -18
  21. snowflake/ml/jobs/__init__.py +21 -0
  22. snowflake/ml/jobs/_utils/constants.py +51 -0
  23. snowflake/ml/jobs/_utils/payload_utils.py +352 -0
  24. snowflake/ml/jobs/_utils/spec_utils.py +298 -0
  25. snowflake/ml/jobs/_utils/types.py +39 -0
  26. snowflake/ml/jobs/decorators.py +91 -0
  27. snowflake/ml/jobs/job.py +113 -0
  28. snowflake/ml/jobs/manager.py +298 -0
  29. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  30. snowflake/ml/model/_client/ops/model_ops.py +13 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  32. snowflake/ml/model/_client/sql/model_version.py +11 -0
  33. snowflake/ml/model/_client/sql/service.py +13 -6
  34. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  35. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  37. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  38. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  40. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  41. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  42. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  43. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  44. snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
  45. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  46. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  49. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  50. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  51. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  52. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  53. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  54. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  56. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  57. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  58. snowflake/ml/model/_signatures/base_handler.py +1 -2
  59. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  60. snowflake/ml/model/_signatures/numpy_handler.py +6 -7
  61. snowflake/ml/model/_signatures/pandas_handler.py +3 -3
  62. snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
  63. snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
  64. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  65. snowflake/ml/model/model_signature.py +17 -4
  66. snowflake/ml/model/type_hints.py +1 -0
  67. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  68. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  69. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  70. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  71. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  72. snowflake/ml/modeling/cluster/birch.py +6 -3
  73. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  74. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  75. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  76. snowflake/ml/modeling/cluster/k_means.py +6 -3
  77. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  78. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  79. snowflake/ml/modeling/cluster/optics.py +6 -3
  80. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  81. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  82. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  83. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  84. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  85. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  86. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  87. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  88. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  89. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  90. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  91. snowflake/ml/modeling/covariance/oas.py +6 -3
  92. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  93. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  94. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  95. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  96. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  97. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  98. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  99. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  100. snowflake/ml/modeling/decomposition/pca.py +6 -3
  101. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  102. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  103. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  104. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  105. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  106. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  107. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  108. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  109. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  110. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  111. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  112. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  113. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  114. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  115. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  116. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  117. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  118. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  119. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  120. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  121. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  122. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  123. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  124. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  125. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  126. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  127. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  128. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  129. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  130. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  131. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  132. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  133. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  134. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  135. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  136. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  137. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  138. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  139. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  140. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  141. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  142. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  143. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  144. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  145. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  146. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  148. snowflake/ml/modeling/linear_model/lars.py +6 -3
  149. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  151. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  152. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  153. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  154. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  155. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  156. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  157. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  158. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  159. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  160. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  161. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  162. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  163. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  164. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  165. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  166. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  167. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  168. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  169. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  170. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  171. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  172. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  173. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  174. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  175. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  176. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  177. snowflake/ml/modeling/manifold/isomap.py +6 -3
  178. snowflake/ml/modeling/manifold/mds.py +6 -3
  179. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  180. snowflake/ml/modeling/manifold/tsne.py +6 -3
  181. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  182. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  183. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  184. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  185. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  186. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  187. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  188. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  189. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  190. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  191. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  192. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  193. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  194. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  195. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  196. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  197. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  198. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  199. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  200. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  201. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  202. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  203. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  204. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  205. snowflake/ml/modeling/pipeline/pipeline.py +16 -178
  206. snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  209. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  210. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  211. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  212. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  213. snowflake/ml/modeling/svm/svc.py +6 -3
  214. snowflake/ml/modeling/svm/svr.py +6 -3
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
  223. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  224. snowflake/ml/registry/_manager/model_manager.py +70 -33
  225. snowflake/ml/registry/registry.py +41 -22
  226. snowflake/ml/version.py +1 -1
  227. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
  228. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
  229. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
  230. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  231. snowflake/ml/fileset/parquet_parser.py +0 -170
  232. snowflake/ml/fileset/tf_dataset.py +0 -88
  233. snowflake/ml/fileset/torch_datapipe.py +0 -57
  234. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  235. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  236. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
  237. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,32 @@
1
- from snowflake.cortex._classify_text import ClassifyText
2
- from snowflake.cortex._complete import Complete, CompleteOptions
3
- from snowflake.cortex._embed_text_768 import EmbedText768
4
- from snowflake.cortex._embed_text_1024 import EmbedText1024
5
- from snowflake.cortex._extract_answer import ExtractAnswer
1
+ from snowflake.cortex._classify_text import ClassifyText, classify_text
2
+ from snowflake.cortex._complete import Complete, CompleteOptions, complete
3
+ from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
4
+ from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
5
+ from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
6
6
  from snowflake.cortex._finetune import Finetune, FinetuneJob, FinetuneStatus
7
- from snowflake.cortex._sentiment import Sentiment
8
- from snowflake.cortex._summarize import Summarize
9
- from snowflake.cortex._translate import Translate
7
+ from snowflake.cortex._sentiment import Sentiment, sentiment
8
+ from snowflake.cortex._summarize import Summarize, summarize
9
+ from snowflake.cortex._translate import Translate, translate
10
10
 
11
11
  __all__ = [
12
12
  "ClassifyText",
13
+ "classify_text",
13
14
  "Complete",
15
+ "complete",
14
16
  "CompleteOptions",
15
17
  "EmbedText768",
18
+ "embed_text_768",
16
19
  "EmbedText1024",
20
+ "embed_text_1024",
17
21
  "ExtractAnswer",
22
+ "extract_answer",
18
23
  "Finetune",
19
24
  "FinetuneJob",
20
25
  "FinetuneStatus",
21
26
  "Sentiment",
27
+ "sentiment",
22
28
  "Summarize",
29
+ "summarize",
23
30
  "Translate",
31
+ "translate",
24
32
  ]
@@ -1,5 +1,7 @@
1
1
  from typing import List, Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,7 +10,7 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def ClassifyText(
13
+ def classify_text(
12
14
  str_input: Union[str, snowpark.Column],
13
15
  categories: Union[List[str], snowpark.Column],
14
16
  session: Optional[snowpark.Session] = None,
@@ -34,3 +36,12 @@ def _classify_text_impl(
34
36
  session: Optional[snowpark.Session] = None,
35
37
  ) -> Union[str, snowpark.Column]:
36
38
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories))
39
+
40
+
41
+ ClassifyText = deprecated(
42
+ "ClassifyText() is deprecated and will be removed in a future release. Please use classify_text() instead."
43
+ )(
44
+ telemetry.send_api_usage_telemetry(
45
+ project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
46
+ )(classify_text)
47
+ )
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Uni
6
6
  from urllib.parse import urlunparse
7
7
 
8
8
  import requests
9
- from typing_extensions import NotRequired
9
+ from typing_extensions import NotRequired, deprecated
10
10
 
11
11
  from snowflake import snowpark
12
12
  from snowflake.cortex._sse_client import SSEClient
@@ -49,6 +49,10 @@ class CompleteOptions(TypedDict):
49
49
  generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
50
50
  that the model outputs, while temperature influences which tokens are chosen at each step. """
51
51
 
52
+ guardrails: NotRequired[bool]
53
+ """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
54
+ from the language model. """
55
+
52
56
 
53
57
  class ResponseParseException(Exception):
54
58
  """This exception is raised when the server response cannot be parsed."""
@@ -56,6 +60,15 @@ class ResponseParseException(Exception):
56
60
  pass
57
61
 
58
62
 
63
+ class GuardrailsOptions(TypedDict):
64
+ enabled: bool
65
+ """A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
66
+ from the language model."""
67
+
68
+ response_when_unsafe: str
69
+ """The response to return when the language model generates unsafe or harmful content."""
70
+
71
+
59
72
  _MAX_RETRY_SECONDS = 30
60
73
 
61
74
 
@@ -117,6 +130,12 @@ def _make_request_body(
117
130
  data["temperature"] = options["temperature"]
118
131
  if "top_p" in options:
119
132
  data["top_p"] = options["top_p"]
133
+ if "guardrails" in options and options["guardrails"]:
134
+ guardrails_options: GuardrailsOptions = {
135
+ "enabled": True,
136
+ "response_when_unsafe": "Response filtered by Cortex Guard",
137
+ }
138
+ data["guardrails"] = guardrails_options
120
139
  return data
121
140
 
122
141
 
@@ -127,8 +146,26 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
127
146
  response.status_code = int(raw_resp["status"])
128
147
  response.headers = raw_resp["headers"]
129
148
 
149
+ request_id = None
150
+ for key, value in raw_resp["headers"].items():
151
+ # Note: there is some whitespace in the headers making it not possible
152
+ # to directly index the header reliably.
153
+ if key.strip().lower() == "x-snowflake-request-id":
154
+ request_id = value
155
+ break
156
+
130
157
  data = raw_resp["content"]
131
- data = json.loads(data)
158
+ try:
159
+ data = json.loads(data)
160
+ except json.JSONDecodeError:
161
+ raise ValueError(f"Request failed (request id: {request_id})")
162
+
163
+ if response.status_code < 200 or response.status_code >= 300:
164
+ if "message" not in data:
165
+ raise ValueError(f"Request failed (request id: {request_id})")
166
+ message = data["message"]
167
+ raise ValueError(f"Request failed: {message} (request id: {request_id})")
168
+
132
169
  # Convert the dictionary to a string format that resembles the SSE event format
133
170
  # For example, if the dict is {'event': 'message', 'data': 'your data'}, it should be formatted like this:
134
171
  sse_format_data = ""
@@ -144,6 +181,7 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
144
181
 
145
182
  @retry
146
183
  def _call_complete_xp(
184
+ snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
147
185
  model: str,
148
186
  prompt: Union[str, List[ConversationMessage]],
149
187
  options: Optional[CompleteOptions] = None,
@@ -151,9 +189,8 @@ def _call_complete_xp(
151
189
  ) -> requests.Response:
152
190
  headers = _make_common_request_headers()
153
191
  body = _make_request_body(model, prompt, options)
154
- import _snowflake
155
-
156
- raw_resp = _snowflake.send_snow_api_request("POST", _REST_COMPLETE_URL, {}, headers, body, {}, deadline)
192
+ assert snow_api_xp_request_handler is not None
193
+ raw_resp = snow_api_xp_request_handler("POST", _REST_COMPLETE_URL, {}, headers, body, {}, deadline)
157
194
  return _xp_dict_to_response(raw_resp)
158
195
 
159
196
 
@@ -218,17 +255,26 @@ def _complete_call_sql_function_snowpark(
218
255
 
219
256
 
220
257
  def _complete_non_streaming_immediate(
258
+ snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
221
259
  model: str,
222
260
  prompt: Union[str, List[ConversationMessage]],
223
261
  options: Optional[CompleteOptions],
224
262
  session: Optional[snowpark.Session] = None,
225
263
  deadline: Optional[float] = None,
226
264
  ) -> str:
227
- response = _complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
265
+ response = _complete_rest(
266
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
267
+ model=model,
268
+ prompt=prompt,
269
+ options=options,
270
+ session=session,
271
+ deadline=deadline,
272
+ )
228
273
  return "".join(response)
229
274
 
230
275
 
231
276
  def _complete_non_streaming_impl(
277
+ snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
232
278
  function: str,
233
279
  model: Union[str, snowpark.Column],
234
280
  prompt: Union[str, List[ConversationMessage], snowpark.Column],
@@ -246,19 +292,31 @@ def _complete_non_streaming_impl(
246
292
  if isinstance(options, snowpark.Column):
247
293
  raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
248
294
  return _complete_non_streaming_immediate(
249
- model=model, prompt=prompt, options=options, session=session, deadline=deadline
295
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
296
+ model=model,
297
+ prompt=prompt,
298
+ options=options,
299
+ session=session,
300
+ deadline=deadline,
250
301
  )
251
302
 
252
303
 
253
304
  def _complete_rest(
305
+ snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
254
306
  model: str,
255
307
  prompt: Union[str, List[ConversationMessage]],
256
308
  options: Optional[CompleteOptions] = None,
257
309
  session: Optional[snowpark.Session] = None,
258
310
  deadline: Optional[float] = None,
259
311
  ) -> Iterator[str]:
260
- if is_in_stored_procedure(): # type: ignore[no-untyped-call]
261
- response = _call_complete_xp(model=model, prompt=prompt, options=options, deadline=deadline)
312
+ if snow_api_xp_request_handler is not None:
313
+ response = _call_complete_xp(
314
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
315
+ model=model,
316
+ prompt=prompt,
317
+ options=options,
318
+ deadline=deadline,
319
+ )
262
320
  else:
263
321
  response = _call_complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
264
322
  assert response.status_code >= 200 and response.status_code < 300
@@ -268,10 +326,11 @@ def _complete_rest(
268
326
  def _complete_impl(
269
327
  model: Union[str, snowpark.Column],
270
328
  prompt: Union[str, List[ConversationMessage], snowpark.Column],
329
+ snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]] = None,
330
+ function: str = "snowflake.cortex.complete",
271
331
  options: Optional[CompleteOptions] = None,
272
332
  session: Optional[snowpark.Session] = None,
273
333
  stream: bool = False,
274
- function: str = "snowflake.cortex.complete",
275
334
  timeout: Optional[float] = None,
276
335
  deadline: Optional[float] = None,
277
336
  ) -> Union[str, Iterator[str], snowpark.Column]:
@@ -284,14 +343,29 @@ def _complete_impl(
284
343
  raise ValueError("in REST mode, 'model' must be a string")
285
344
  if not isinstance(prompt, str) and not isinstance(prompt, List):
286
345
  raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
287
- return _complete_rest(model=model, prompt=prompt, options=options, session=session, deadline=deadline)
288
- return _complete_non_streaming_impl(function, model, prompt, options, session, deadline)
346
+ return _complete_rest(
347
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
348
+ model=model,
349
+ prompt=prompt,
350
+ options=options,
351
+ session=session,
352
+ deadline=deadline,
353
+ )
354
+ return _complete_non_streaming_impl(
355
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
356
+ function=function,
357
+ model=model,
358
+ prompt=prompt,
359
+ options=options,
360
+ session=session,
361
+ deadline=deadline,
362
+ )
289
363
 
290
364
 
291
365
  @telemetry.send_api_usage_telemetry(
292
366
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
293
367
  )
294
- def Complete(
368
+ def complete(
295
369
  model: Union[str, snowpark.Column],
296
370
  prompt: Union[str, List[ConversationMessage], snowpark.Column],
297
371
  *,
@@ -319,10 +393,19 @@ def Complete(
319
393
  Returns:
320
394
  A column of string responses.
321
395
  """
396
+
397
+ # Set the XP snow api function, if available.
398
+ snow_api_xp_request_handler = None
399
+ if is_in_stored_procedure(): # type: ignore[no-untyped-call]
400
+ import _snowflake
401
+
402
+ snow_api_xp_request_handler = _snowflake.send_snow_api_request
403
+
322
404
  try:
323
405
  return _complete_impl(
324
406
  model,
325
407
  prompt,
408
+ snow_api_xp_request_handler=snow_api_xp_request_handler,
326
409
  options=options,
327
410
  session=session,
328
411
  stream=stream,
@@ -331,3 +414,8 @@ def Complete(
331
414
  )
332
415
  except ValueError as err:
333
416
  raise err
417
+
418
+
419
+ Complete = deprecated("Complete() is deprecated and will be removed in a future release. Use complete() instead")(
420
+ telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(complete)
421
+ )
@@ -1,5 +1,7 @@
1
1
  from typing import List, Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def EmbedText1024(
13
+ def embed_text_1024(
12
14
  model: Union[str, snowpark.Column],
13
15
  text: Union[str, snowpark.Column],
14
16
  session: Optional[snowpark.Session] = None,
15
17
  ) -> Union[List[float], snowpark.Column]:
16
- """TextEmbed calls into the LLM inference service to embed the text.
18
+ """Calls into the LLM inference service to embed the text.
17
19
 
18
20
  Args:
19
21
  model: A Column of strings representing the model to use for embedding. The value
@@ -35,3 +37,8 @@ def _embed_text_1024_impl(
35
37
  session: Optional[snowpark.Session] = None,
36
38
  ) -> Union[List[float], snowpark.Column]:
37
39
  return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
40
+
41
+
42
+ EmbedText1024 = deprecated(
43
+ "EmbedText1024() is deprecated and will be removed in a future release. Use embed_text_1024() instead"
44
+ )(telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(embed_text_1024))
@@ -1,5 +1,7 @@
1
1
  from typing import List, Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def EmbedText768(
13
+ def embed_text_768(
12
14
  model: Union[str, snowpark.Column],
13
15
  text: Union[str, snowpark.Column],
14
16
  session: Optional[snowpark.Session] = None,
15
17
  ) -> Union[List[float], snowpark.Column]:
16
- """TextEmbed calls into the LLM inference service to embed the text.
18
+ """Calls into the LLM inference service to embed the text.
17
19
 
18
20
  Args:
19
21
  model: A Column of strings representing the model to use for embedding. The value
@@ -35,3 +37,8 @@ def _embed_text_768_impl(
35
37
  session: Optional[snowpark.Session] = None,
36
38
  ) -> Union[List[float], snowpark.Column]:
37
39
  return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
40
+
41
+
42
+ EmbedText768 = deprecated(
43
+ "EmbedText768() is deprecated and will be removed in a future release. Use embed_text_768() instead"
44
+ )(telemetry.send_api_usage_telemetry(project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT)(embed_text_768))
@@ -1,5 +1,7 @@
1
1
  from typing import Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,12 +10,12 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def ExtractAnswer(
13
+ def extract_answer(
12
14
  from_text: Union[str, snowpark.Column],
13
15
  question: Union[str, snowpark.Column],
14
16
  session: Optional[snowpark.Session] = None,
15
17
  ) -> Union[str, snowpark.Column]:
16
- """ExtractAnswer calls into the LLM inference service to extract an answer from within specified text.
18
+ """Calls into the LLM inference service to extract an answer from within specified text.
17
19
 
18
20
  Args:
19
21
  from_text: A Column of strings representing input text.
@@ -34,3 +36,8 @@ def _extract_answer_impl(
34
36
  session: Optional[snowpark.Session] = None,
35
37
  ) -> Union[str, snowpark.Column]:
36
38
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, from_text, question))
39
+
40
+
41
# Deprecated CamelCase alias for the renamed extract_answer(). Calling it goes through
# typing_extensions.deprecated (deprecation warning) and then the telemetry-wrapped function.
# NOTE(review): extract_answer is already decorated with @send_api_usage_telemetry; confirm the
# extra wrapping here (and thus possibly two telemetry records per call) is intended.
ExtractAnswer = deprecated(
    "ExtractAnswer() is deprecated and will be removed in a future release. Use extract_answer() instead"
)(
    telemetry.send_api_usage_telemetry(
        project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
    )(extract_answer)
)
@@ -1,5 +1,7 @@
1
1
  from typing import Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,10 +10,10 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def Sentiment(
13
+ def sentiment(
12
14
  text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
13
15
  ) -> Union[float, snowpark.Column]:
14
- """Sentiment calls into the LLM inference service to perform sentiment analysis on the input text.
16
+ """Calls into the LLM inference service to perform sentiment analysis on the input text.
15
17
 
16
18
  Args:
17
19
  text: A Column of text strings to send to the LLM.
@@ -31,3 +33,8 @@ def _sentiment_impl(
31
33
  if isinstance(output, snowpark.Column):
32
34
  return output
33
35
  return float(cast(str, output))
36
+
37
+
38
# Deprecated CamelCase alias for the renamed sentiment(). Calling it goes through
# typing_extensions.deprecated (deprecation warning) and then the telemetry-wrapped function.
# NOTE(review): sentiment is already decorated with @send_api_usage_telemetry; confirm the
# extra wrapping here (and thus possibly two telemetry records per call) is intended.
Sentiment = deprecated(
    "Sentiment() is deprecated and will be removed in a future release. Use sentiment() instead"
)(
    telemetry.send_api_usage_telemetry(
        project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
    )(sentiment)
)
@@ -1,5 +1,7 @@
1
1
  from typing import Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,11 +10,11 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def Summarize(
13
+ def summarize(
12
14
  text: Union[str, snowpark.Column],
13
15
  session: Optional[snowpark.Session] = None,
14
16
  ) -> Union[str, snowpark.Column]:
15
- """Summarize calls into the LLM inference service to summarize the input text.
17
+ """Calls into the LLM inference service to summarize the input text.
16
18
 
17
19
  Args:
18
20
  text: A Column of strings to summarize.
@@ -31,3 +33,8 @@ def _summarize_impl(
31
33
  session: Optional[snowpark.Session] = None,
32
34
  ) -> Union[str, snowpark.Column]:
33
35
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, text))
36
+
37
+
38
# Deprecated CamelCase alias for the renamed summarize(). Calling it goes through
# typing_extensions.deprecated (deprecation warning) and then the telemetry-wrapped function.
# NOTE(review): summarize is already decorated with @send_api_usage_telemetry; confirm the
# extra wrapping here (and thus possibly two telemetry records per call) is intended.
Summarize = deprecated(
    "Summarize() is deprecated and will be removed in a future release. Use summarize() instead"
)(
    telemetry.send_api_usage_telemetry(
        project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
    )(summarize)
)
@@ -1,5 +1,7 @@
1
1
  from typing import Optional, Union, cast
2
2
 
3
+ from typing_extensions import deprecated
4
+
3
5
  from snowflake import snowpark
4
6
  from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
7
  from snowflake.ml._internal import telemetry
@@ -8,13 +10,13 @@ from snowflake.ml._internal import telemetry
8
10
  @telemetry.send_api_usage_telemetry(
9
11
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
12
  )
11
- def Translate(
13
+ def translate(
12
14
  text: Union[str, snowpark.Column],
13
15
  from_language: Union[str, snowpark.Column],
14
16
  to_language: Union[str, snowpark.Column],
15
17
  session: Optional[snowpark.Session] = None,
16
18
  ) -> Union[str, snowpark.Column]:
17
- """Translate calls into the LLM inference service to perform translation.
19
+ """Calls into the LLM inference service to perform translation.
18
20
 
19
21
  Args:
20
22
  text: A Column of strings to translate.
@@ -37,3 +39,8 @@ def _translate_impl(
37
39
  session: Optional[snowpark.Session] = None,
38
40
  ) -> Union[str, snowpark.Column]:
39
41
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, text, from_language, to_language))
42
+
43
+
44
# Deprecated CamelCase alias for the renamed translate(). Calling it goes through
# typing_extensions.deprecated (deprecation warning) and then the telemetry-wrapped function.
# NOTE(review): translate is already decorated with @send_api_usage_telemetry; confirm the
# extra wrapping here (and thus possibly two telemetry records per call) is intended.
Translate = deprecated(
    "Translate() is deprecated and will be removed in a future release. Use translate() instead"
)(
    telemetry.send_api_usage_telemetry(
        project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
    )(translate)
)
@@ -15,7 +15,6 @@ import snowflake.connector
15
15
  from snowflake.ml._internal import env as snowml_env
16
16
  from snowflake.ml._internal.utils import query_result_checker
17
17
  from snowflake.snowpark import context, exceptions, session
18
- from snowflake.snowpark._internal import utils as snowpark_utils
19
18
 
20
19
 
21
20
  class CONDA_OS(Enum):
@@ -344,55 +343,6 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
344
343
  return new_req
345
344
 
346
345
 
347
- def get_matched_package_versions_in_snowflake_conda_channel(
348
- req: requirements.Requirement,
349
- python_version: str = snowml_env.PYTHON_VERSION,
350
- conda_os: CONDA_OS = CONDA_OS.LINUX_64,
351
- ) -> List[version.Version]:
352
- """Search the snowflake anaconda channel for packages that matches the specifier. Note that this will be the
353
- source of truth for checking whether a package indeed exists in Snowflake conda channel.
354
-
355
- Given that a package comes in different architectures, we only check for the Linux x86_64 architecture and assume
356
- the package exists in other architectures. If such an assumption does not hold true for a certain package, the
357
- caller should specify the architecture to search for.
358
-
359
- Args:
360
- req: Requirement specifier.
361
- python_version: A string of python version where model is run.
362
- conda_os: Specified platform to search availability of the package.
363
-
364
- Returns:
365
- List of package versions that meet the requirement specifier.
366
- """
367
- # Move the retryable_http import here as when UDF import this file, it won't have the "requests" dependency.
368
- from snowflake.ml._internal.utils import retryable_http
369
-
370
- assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
371
-
372
- url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
373
-
374
- if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
375
- try:
376
- http_client = retryable_http.get_http_client()
377
- parsed_python_version = version.Version(python_version)
378
- python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}"
379
- repodata = http_client.get(url).json()
380
- assert isinstance(repodata, dict)
381
- packages_info = repodata["packages"]
382
- assert isinstance(packages_info, dict)
383
- version_list = [
384
- version.parse(package_info["version"])
385
- for package_info in packages_info.values()
386
- if package_info["name"] == req.name and python_version_build_str in package_info["build"]
387
- ]
388
- _SNOWFLAKE_CONDA_PACKAGE_CACHE[req.name] = version_list
389
- except Exception:
390
- pass
391
-
392
- matched_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
393
- return matched_versions
394
-
395
-
396
346
  def get_matched_package_versions_in_information_schema_with_active_session(
397
347
  reqs: List[requirements.Requirement], python_version: str
398
348
  ) -> Dict[str, List[version.Version]]:
@@ -404,7 +354,10 @@ def get_matched_package_versions_in_information_schema_with_active_session(
404
354
 
405
355
 
406
356
  def get_matched_package_versions_in_information_schema(
407
- session: session.Session, reqs: List[requirements.Requirement], python_version: str
357
+ session: session.Session,
358
+ reqs: List[requirements.Requirement],
359
+ python_version: str,
360
+ statement_params: Optional[Dict[str, Any]] = None,
408
361
  ) -> Dict[str, List[version.Version]]:
409
362
  """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
410
363
  Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
@@ -414,6 +367,7 @@ def get_matched_package_versions_in_information_schema(
414
367
  session: Snowflake connection session.
415
368
  reqs: List of requirement specifiers.
416
369
  python_version: A string of python version where model is run.
370
+ statement_params: Optional statement parameters.
417
371
 
418
372
  Returns:
419
373
  A Dict, whose key is the package name, and value is a list of versions match the requirements.
@@ -451,8 +405,9 @@ def get_matched_package_versions_in_information_schema(
451
405
  query_result_checker.SqlResultValidator(
452
406
  session=session,
453
407
  query=sql,
408
+ statement_params=statement_params,
454
409
  )
455
- .has_column("VERSION")
410
+ .has_column("VERSION", allow_empty=True)
456
411
  .has_dimensions(expected_rows=None, expected_cols=2)
457
412
  .validate()
458
413
  )
@@ -0,0 +1,87 @@
1
+ import json
2
+ from typing import Any, Dict, Optional
3
+
4
+ from absl import logging
5
+
6
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
7
+ from snowflake.ml._internal.utils import query_result_checker
8
+ from snowflake.snowpark import (
9
+ exceptions as snowpark_exceptions,
10
+ session as snowpark_session,
11
+ )
12
+
13
+
14
+ class PlatformCapabilities:
15
+ """Class that retrieves platform feature values for the currently running server.
16
+
17
+ Example usage:
18
+ ```
19
+ pc = PlatformCapabilities.get_instance(session)
20
+ if pc.is_nested_function_enabled():
21
+ # Nested functions are enabled.
22
+ print("Nested functions are enabled.")
23
+ else:
24
+ # Nested functions are disabled.
25
+ print("Nested functions are disabled or not supported.")
26
+ ```
27
+ """
28
+
29
+ _instance: Optional["PlatformCapabilities"] = None
30
+
31
+ @classmethod
32
+ def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
33
+ if not cls._instance:
34
+ cls._instance = cls(session)
35
+ return cls._instance
36
+
37
+ def is_nested_function_enabled(self) -> bool:
38
+ return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
39
+
40
+ @staticmethod
41
+ def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
42
+ try:
43
+ result = (
44
+ query_result_checker.SqlResultValidator(
45
+ session=session,
46
+ query="SELECT SYSTEM$ML_PLATFORM_CAPABILITIES() AS FEATURES;",
47
+ )
48
+ .has_dimensions(expected_rows=1, expected_cols=1)
49
+ .has_column("FEATURES")
50
+ .validate()[0]
51
+ )
52
+ if "FEATURES" in result:
53
+ capabilities_json: str = result["FEATURES"]
54
+ try:
55
+ parsed_json = json.loads(capabilities_json)
56
+ assert isinstance(parsed_json, dict), f"Expected JSON object, got {type(parsed_json)}"
57
+ return parsed_json
58
+ except json.JSONDecodeError as e:
59
+ message = f"""Unable to parse JSON from: "{capabilities_json}"; Error="{e}"."""
60
+ raise exceptions.SnowflakeMLException(
61
+ error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
62
+ )
63
+ except snowpark_exceptions.SnowparkSQLException as e:
64
+ logging.debug(f"Failed to retrieve platform capabilities: {e}")
65
+ # This can happen is server side is older than 9.2. That is fine.
66
+ return {}
67
+
68
+ def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
69
+ if not session:
70
+ session = next(iter(snowpark_session._get_active_sessions()))
71
+ assert session, "Missing active session object"
72
+ self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
73
+
74
+ def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
75
+ value = self.features.get(feature_name, default_value)
76
+ if isinstance(value, bool):
77
+ return value
78
+ if isinstance(value, int) and value in [0, 1]:
79
+ return value == 1
80
+ if isinstance(value, str):
81
+ if value.lower() in ["true", "1"]:
82
+ return True
83
+ elif value.lower() in ["false", "0"]:
84
+ return False
85
+ else:
86
+ raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
87
+ raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")