snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -23,7 +23,7 @@ def Complete(
23
23
  A column of string responses.
24
24
  """
25
25
 
26
- return _complete_impl("snowflake.ml.complete", model, prompt, session=session)
26
+ return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
27
27
 
28
28
 
29
29
  def _complete_impl(
@@ -25,7 +25,7 @@ def ExtractAnswer(
25
25
  A column of strings containing answers.
26
26
  """
27
27
 
28
- return _extract_answer_impl("snowflake.ml.extract_answer", from_text, question, session=session)
28
+ return _extract_answer_impl("snowflake.cortex.extract_answer", from_text, question, session=session)
29
29
 
30
30
 
31
31
  def _extract_answer_impl(
@@ -22,7 +22,7 @@ def Sentiment(
22
22
  A column of floats. 1 represents positive sentiment, -1 represents negative sentiment.
23
23
  """
24
24
 
25
- return _sentiment_impl("snowflake.ml.sentiment", text, session=session)
25
+ return _sentiment_impl("snowflake.cortex.sentiment", text, session=session)
26
26
 
27
27
 
28
28
  def _sentiment_impl(
@@ -23,7 +23,7 @@ def Summarize(
23
23
  A column of string summaries.
24
24
  """
25
25
 
26
- return _summarize_impl("snowflake.ml.summarize", text, session=session)
26
+ return _summarize_impl("snowflake.cortex.summarize", text, session=session)
27
27
 
28
28
 
29
29
  def _summarize_impl(
@@ -27,7 +27,7 @@ def Translate(
27
27
  A column of string translations.
28
28
  """
29
29
 
30
- return _translate_impl("snowflake.ml.translate", text, from_language, to_language, session=session)
30
+ return _translate_impl("snowflake.cortex.translate", text, from_language, to_language, session=session)
31
31
 
32
32
 
33
33
  def _translate_impl(
@@ -4,6 +4,7 @@ import pathlib
4
4
  import re
5
5
  import textwrap
6
6
  import warnings
7
+ from enum import Enum
7
8
  from importlib import metadata as importlib_metadata
8
9
  from typing import Any, DefaultDict, Dict, List, Optional, Tuple
9
10
 
@@ -18,10 +19,22 @@ from snowflake.ml._internal.exceptions import (
18
19
  )
19
20
  from snowflake.ml._internal.utils import query_result_checker
20
21
  from snowflake.snowpark import session
22
+ from snowflake.snowpark._internal import utils as snowpark_utils
23
+
24
+
25
+ class CONDA_OS(Enum):
26
+ LINUX_64 = "linux-64"
27
+ LINUX_AARCH64 = "linux-aarch64"
28
+ OSX_64 = "osx-64"
29
+ OSX_ARM64 = "osx-arm64"
30
+ WIN_64 = "win-64"
31
+ NO_ARCH = "noarch"
32
+
21
33
 
22
34
  _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
23
35
  _NODEFAULTS = "nodefaults"
24
36
  _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None
37
+ _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
25
38
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
26
39
 
27
40
  DEFAULT_CHANNEL_NAME = ""
@@ -217,6 +230,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
217
230
  warnings.warn(
218
231
  f"Package requirement {str(pip_req)} specified, while version {local_dist_version} is installed. "
219
232
  "Local version will be ignored to conform to package requirement.",
233
+ stacklevel=2,
220
234
  category=UserWarning,
221
235
  )
222
236
  return pip_req
@@ -265,10 +279,58 @@ def _check_runtime_version_column_existence(session: session.Session) -> bool:
265
279
  return result == 1
266
280
 
267
281
 
268
- def validate_requirements_in_snowflake_conda_channel(
282
+ def get_matched_package_versions_in_snowflake_conda_channel(
283
+ req: requirements.Requirement,
284
+ python_version: str = snowml_env.PYTHON_VERSION,
285
+ conda_os: CONDA_OS = CONDA_OS.LINUX_64,
286
+ ) -> List[version.Version]:
287
+ """Search the snowflake anaconda channel for packages that matches the specifier. Note that this will be the
288
+ source of truth for checking whether a package indeed exists in Snowflake conda channel.
289
+
290
+ Given that a package comes in different architectures, we only check for the Linux x86_64 architecture and assume
291
+ the package exists in other architectures. If such an assumption does not hold true for a certain package, the
292
+ caller should specify the architecture to search for.
293
+
294
+ Args:
295
+ req: Requirement specifier.
296
+ python_version: A string of python version where model is run.
297
+ conda_os: Specified platform to search availability of the package.
298
+
299
+ Returns:
300
+ List of package versions that meet the requirement specifier.
301
+ """
302
+ # Move the retryable_http import here as when UDF import this file, it won't have the "requests" dependency.
303
+ from snowflake.ml._internal.utils import retryable_http
304
+
305
+ assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
306
+
307
+ url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
308
+
309
+ if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
310
+ http_client = retryable_http.get_http_client()
311
+ parsed_python_version = version.Version(python_version)
312
+ python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}"
313
+ repodata = http_client.get(url).json()
314
+ assert isinstance(repodata, dict)
315
+ packages_info = repodata["packages"]
316
+ assert isinstance(packages_info, dict)
317
+ version_list = [
318
+ version.parse(package_info["version"])
319
+ for package_info in packages_info.values()
320
+ if package_info["name"] == req.name and python_version_build_str in package_info["build"]
321
+ ]
322
+ _SNOWFLAKE_CONDA_PACKAGE_CACHE[req.name] = version_list
323
+
324
+ matched_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
325
+ return matched_versions
326
+
327
+
328
+ def validate_requirements_in_information_schema(
269
329
  session: session.Session, reqs: List[requirements.Requirement], python_version: str
270
330
  ) -> Optional[List[str]]:
271
- """Search the snowflake anaconda channel for packages with version meet the specifier.
331
+ """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
332
+ Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
333
+ exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
272
334
 
273
335
  Args:
274
336
  session: Snowflake connection session.
@@ -285,7 +347,7 @@ def validate_requirements_in_snowflake_conda_channel(
285
347
  ret_list = []
286
348
  reqs_to_request = []
287
349
  for req in reqs:
288
- if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
350
+ if req.name not in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
289
351
  reqs_to_request.append(req)
290
352
  if reqs_to_request:
291
353
  pkg_names_str = " OR ".join(
@@ -326,13 +388,13 @@ def validate_requirements_in_snowflake_conda_channel(
326
388
  for row in result:
327
389
  req_name = row["PACKAGE_NAME"]
328
390
  req_ver = version.parse(row["VERSION"])
329
- cached_req_ver_list = _SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req_name, [])
391
+ cached_req_ver_list = _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req_name, [])
330
392
  cached_req_ver_list.append(req_ver)
331
- _SNOWFLAKE_CONDA_PACKAGE_CACHE[req_name] = cached_req_ver_list
393
+ _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list
332
394
  except snowflake.connector.DataError:
333
395
  return None
334
396
  for req in reqs:
335
- available_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
397
+ available_versions = list(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
336
398
  if not available_versions:
337
399
  return None
338
400
  else:
@@ -28,6 +28,7 @@ import cloudpickle
28
28
 
29
29
  from snowflake import snowpark
30
30
  from snowflake.ml._internal.exceptions import exceptions
31
+ from snowflake.snowpark import exceptions as snowpark_exceptions
31
32
 
32
33
  GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")
33
34
 
@@ -286,8 +287,16 @@ def stage_file_exists(
286
287
  return False
287
288
 
288
289
 
290
+ def _retry_on_sql_error(exception: Exception) -> bool:
291
+ return isinstance(exception, snowpark_exceptions.SnowparkSQLException)
292
+
293
+
289
294
  def upload_directory_to_stage(
290
- session: snowpark.Session, local_path: pathlib.Path, stage_path: pathlib.PurePosixPath
295
+ session: snowpark.Session,
296
+ local_path: pathlib.Path,
297
+ stage_path: pathlib.PurePosixPath,
298
+ *,
299
+ statement_params: Optional[Dict[str, Any]] = None,
291
300
  ) -> None:
292
301
  """Upload a local folder recursively to a stage and keep the structure.
293
302
 
@@ -295,7 +304,10 @@ def upload_directory_to_stage(
295
304
  session: Snowpark Session.
296
305
  local_path: Local path to upload.
297
306
  stage_path: Base path in the stage.
307
+ statement_params: Statement Params.
298
308
  """
309
+ import retrying
310
+
299
311
  file_operation = snowpark.FileOperation(session=session)
300
312
 
301
313
  for root, _, filenames in os.walk(local_path):
@@ -305,16 +317,26 @@ def upload_directory_to_stage(
305
317
  stage_dir_path = (
306
318
  stage_path / pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix()).parent
307
319
  )
308
- file_operation.put(
320
+ retrying.retry(
321
+ retry_on_exception=_retry_on_sql_error,
322
+ stop_max_attempt_number=5,
323
+ wait_exponential_multiplier=100,
324
+ wait_exponential_max=10000,
325
+ )(file_operation.put)(
309
326
  str(local_file_path),
310
327
  str(stage_dir_path),
311
328
  auto_compress=False,
312
329
  overwrite=False,
330
+ statement_params=statement_params,
313
331
  )
314
332
 
315
333
 
316
334
  def download_directory_from_stage(
317
- session: snowpark.Session, stage_path: pathlib.PurePosixPath, local_path: pathlib.Path
335
+ session: snowpark.Session,
336
+ stage_path: pathlib.PurePosixPath,
337
+ local_path: pathlib.Path,
338
+ *,
339
+ statement_params: Optional[Dict[str, Any]] = None,
318
340
  ) -> None:
319
341
  """Upload a folder in stage recursively to a folder in local and keep the structure.
320
342
 
@@ -322,7 +344,10 @@ def download_directory_from_stage(
322
344
  session: Snowpark Session.
323
345
  stage_path: Stage path to download from.
324
346
  local_path: Local path as the base of destination.
347
+ statement_params: Statement Params.
325
348
  """
349
+ import retrying
350
+
326
351
  file_operation = file_operation = snowpark.FileOperation(session=session)
327
352
  file_list = [
328
353
  pathlib.PurePosixPath(stage_path.parts[0], *pathlib.PurePosixPath(row.name).parts[1:])
@@ -331,4 +356,9 @@ def download_directory_from_stage(
331
356
  for stage_file_path in file_list:
332
357
  local_file_dir = local_path / stage_file_path.relative_to(stage_path).parent
333
358
  local_file_dir.mkdir(parents=True, exist_ok=True)
334
- file_operation.get(str(stage_file_path), str(local_file_dir))
359
+ retrying.retry(
360
+ retry_on_exception=_retry_on_sql_error,
361
+ stop_max_attempt_number=5,
362
+ wait_exponential_multiplier=100,
363
+ wait_exponential_max=10000,
364
+ )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params)
@@ -42,6 +42,7 @@ class TelemetryField(enum.Enum):
42
42
  NAME = "name"
43
43
  # types of telemetry
44
44
  TYPE_FUNCTION_USAGE = "function_usage"
45
+ TYPE_SNOWML_SPCS_USAGE = "snowml_spcs_usage"
45
46
  # message keys for telemetry
46
47
  KEY_PROJECT = "project"
47
48
  KEY_SUBPROJECT = "subproject"
@@ -207,6 +208,23 @@ def suppress_exceptions(func: Callable[..., Any]) -> Callable[..., Any]:
207
208
  return wrapper
208
209
 
209
210
 
211
+ def send_custom_usage(
212
+ project: str,
213
+ *,
214
+ telemetry_type: str,
215
+ subproject: Optional[str] = None,
216
+ data: Optional[Dict[str, Any]] = None,
217
+ **kwargs: Any,
218
+ ) -> None:
219
+ active_session = next(iter(session._get_active_sessions()))
220
+ assert active_session, "Missing active session object"
221
+
222
+ client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject)
223
+ common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
224
+ data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
225
+ client._send(msg=data)
226
+
227
+
210
228
  def send_api_usage_telemetry(
211
229
  project: str,
212
230
  subproject: Optional[str] = None,
@@ -228,7 +246,8 @@ def send_api_usage_telemetry(
228
246
  custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
229
247
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
230
248
  """
231
- Decorator that sends API usage telemetry.
249
+ Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
250
+ the function.
232
251
 
233
252
  Args:
234
253
  project: Project.
@@ -253,6 +272,51 @@ def send_api_usage_telemetry(
253
272
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
254
273
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
255
274
 
275
+ api_calls: List[Union[Dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
276
+ if api_calls_extractor:
277
+ extracted_api_calls = api_calls_extractor(args[0])
278
+ for api_call in extracted_api_calls:
279
+ if isinstance(api_call, str):
280
+ api_calls.append({TelemetryField.NAME.value: api_call})
281
+ elif callable(api_call):
282
+ api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)})
283
+ else:
284
+ api_calls.append(api_call)
285
+ api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)})
286
+
287
+ sfqids = None
288
+ if sfqids_extractor:
289
+ sfqids = sfqids_extractor(args[0])
290
+
291
+ statement_params = get_function_usage_statement_params(
292
+ project=project,
293
+ subproject=subproject,
294
+ function_category=TelemetryField.FUNC_CAT_USAGE.value,
295
+ function_name=_get_full_func_name(func),
296
+ function_parameters=params,
297
+ api_calls=api_calls,
298
+ custom_tags=custom_tags,
299
+ )
300
+
301
+ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[str, Any]) -> _ReturnValue:
302
+ """
303
+ Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame.
304
+ Used to track APIs returning a Snowpark DataFrame.
305
+
306
+ Args:
307
+ obj: Object to check and update.
308
+ statement_params: Statement parameters.
309
+
310
+ Returns:
311
+ Updated object.
312
+ """
313
+ if isinstance(obj, dataframe.DataFrame):
314
+ if hasattr(obj, "_statement_params") and obj._statement_params:
315
+ obj._statement_params.update(statement_params)
316
+ else:
317
+ obj._statement_params = statement_params # type: ignore[assignment]
318
+ return obj
319
+
256
320
  # prioritize `conn_attr_name` over the active session
257
321
  if conn_attr_name:
258
322
  # raise AttributeError if conn attribute does not exist in `self`
@@ -266,36 +330,20 @@ def send_api_usage_telemetry(
266
330
  # server no default session
267
331
  except snowpark_exceptions.SnowparkSessionException:
268
332
  try:
269
- return func(*args, **kwargs)
333
+ return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
270
334
  except Exception as e:
271
335
  if isinstance(e, snowml_exceptions.SnowflakeMLException):
272
- e = e.original_exception
336
+ raise e.original_exception.with_traceback(e.__traceback__) from None
273
337
  # suppress SnowparkSessionException from telemetry in the stack trace
274
338
  raise e from None
275
339
 
276
340
  conn = active_session._conn._conn
277
341
  if (not active_session.telemetry_enabled) or (conn is None):
278
342
  try:
279
- return func(*args, **kwargs)
343
+ return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
280
344
  except snowml_exceptions.SnowflakeMLException as e:
281
345
  raise e.original_exception from e
282
346
 
283
- api_calls: List[Dict[str, Any]] = []
284
- if api_calls_extractor:
285
- extracted_api_calls = api_calls_extractor(args[0])
286
- for api_call in extracted_api_calls:
287
- if isinstance(api_call, str):
288
- api_calls.append({TelemetryField.NAME.value: api_call})
289
- elif callable(api_call):
290
- api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)})
291
- else:
292
- api_calls.append(api_call)
293
- api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)})
294
-
295
- sfqids = None
296
- if sfqids_extractor:
297
- sfqids = sfqids_extractor(args[0])
298
-
299
347
  # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
300
348
  telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
301
349
  telemetry_args = dict(
@@ -314,22 +362,24 @@ def send_api_usage_telemetry(
314
362
  if hasattr(e, "_snowflake_ml_handled") and e._snowflake_ml_handled:
315
363
  raise e
316
364
  if isinstance(e, snowpark_exceptions.SnowparkClientException):
317
- e = snowml_exceptions.SnowflakeMLException(
365
+ me = snowml_exceptions.SnowflakeMLException(
318
366
  error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e
319
367
  )
320
368
  else:
321
- e = snowml_exceptions.SnowflakeMLException(
369
+ me = snowml_exceptions.SnowflakeMLException(
322
370
  error_code=error_codes.UNDEFINED, original_exception=e
323
371
  )
324
- telemetry_args["error"] = repr(e)
325
- telemetry_args["error_code"] = e.error_code
326
- e.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined]
327
- if e.suppress_source_trace:
328
- raise e.original_exception from None
329
372
  else:
330
- raise e.original_exception from e
373
+ me = e
374
+ telemetry_args["error"] = repr(me)
375
+ telemetry_args["error_code"] = me.error_code
376
+ me.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined]
377
+ if me.suppress_source_trace:
378
+ raise me.original_exception from None
379
+ else:
380
+ raise me.original_exception from e
331
381
  else:
332
- return res
382
+ return update_stmt_params_if_snowpark_df(res, statement_params)
333
383
  finally:
334
384
  telemetry.send_function_usage_telemetry(**telemetry_args)
335
385
  global _log_counter
@@ -343,68 +393,6 @@ def send_api_usage_telemetry(
343
393
  return decorator
344
394
 
345
395
 
346
- def add_stmt_params_to_df(
347
- project: str,
348
- subproject: Optional[str] = None,
349
- *,
350
- function_category: str = TelemetryField.FUNC_CAT_USAGE.value,
351
- func_params_to_log: Optional[Iterable[str]] = None,
352
- api_calls: Optional[
353
- List[
354
- Union[
355
- Dict[str, Union[Callable[..., Any], str]],
356
- Union[Callable[..., Any], str],
357
- ]
358
- ]
359
- ] = None,
360
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
361
- ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
362
- """
363
- Decorator that adds function usage statement parameters to the dataframe returned by the function.
364
-
365
- Args:
366
- project: Project.
367
- subproject: Subproject.
368
- function_category: Function category.
369
- func_params_to_log: Function parameters to log.
370
- api_calls: API calls in the function.
371
- custom_tags: Custom tags.
372
-
373
- Returns:
374
- Decorator that adds function usage statement parameters to the dataframe returned by the decorated function.
375
- """
376
-
377
- def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
378
- @functools.wraps(func)
379
- def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
380
- params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
381
- statement_params = get_function_usage_statement_params(
382
- project=project,
383
- subproject=subproject,
384
- function_category=function_category,
385
- function_name=_get_full_func_name(func),
386
- function_parameters=params,
387
- api_calls=api_calls,
388
- custom_tags=custom_tags,
389
- )
390
-
391
- try:
392
- res = func(*args, **kwargs)
393
- if isinstance(res, dataframe.DataFrame):
394
- if hasattr(res, "_statement_params") and res._statement_params:
395
- res._statement_params.update(statement_params)
396
- else:
397
- res._statement_params = statement_params # type: ignore[assignment]
398
- except Exception:
399
- raise
400
- else:
401
- return res
402
-
403
- return cast(Callable[_Args, _ReturnValue], wrap)
404
-
405
- return decorator
406
-
407
-
408
396
  def _get_full_func_name(func: Callable[..., Any]) -> str:
409
397
  """
410
398
  Get the full function name with module and qualname.
@@ -5,11 +5,23 @@ from requests import adapters
5
5
  from urllib3.util import retry
6
6
 
7
7
 
8
- def get_http_client() -> requests.Session:
9
- # Set up a retry policy for requests
8
+ def get_http_client(total_retries: int = 5, backoff_factor: float = 0.1) -> requests.Session:
9
+ """Construct retryable http client.
10
+
11
+ Args:
12
+ total_retries: Total number of retries to allow.
13
+ backoff_factor: A backoff factor to apply between attempts after the second try. Time to sleep is calculated by
14
+ {backoff factor} * (2 ** ({number of previous retries})). For example, with default retries of 5 and backoff
15
+ factor set to 0.1, each subsequent retry will sleep [0.2s, 0.4s, 0.8s, 1.6s, 3.2s] respectively.
16
+
17
+ Returns:
18
+ requests.Session object.
19
+
20
+ """
21
+
10
22
  retry_strategy = retry.Retry(
11
- total=3, # total number of retries
12
- backoff_factor=0.1, # 100ms initial delay
23
+ total=total_retries,
24
+ backoff_factor=backoff_factor,
13
25
  status_forcelist=[
14
26
  http.HTTPStatus.TOO_MANY_REQUESTS,
15
27
  http.HTTPStatus.INTERNAL_SERVER_ERROR,
@@ -0,0 +1,122 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from typing import Any, Dict, Optional
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.utils import query_result_checker
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ _DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z"
12
+ _COMPUTE_POOL = "compute_pool"
13
+ _CREATED_ON = "created_on"
14
+ _INSTANCE_FAMILY = "instance_family"
15
+ _NAME = "name"
16
+ _TELEMETRY_PROJECT = "MLOps"
17
+ _TELEMETRY_SUBPROJECT = "SpcsDeployment"
18
+ _SERVICE_START = "SPCS_SERVICE_START"
19
+ _SERVICE_END = "SPCS_SERVICE_END"
20
+
21
+
22
def _desc_compute_pool(session: snowpark.Session, compute_pool_name: str) -> Dict[str, Any]:
    """Run ``DESC COMPUTE POOL`` and return the resulting row as a dictionary.

    Args:
        session: Snowpark session the query is issued through.
        compute_pool_name: Name of the compute pool to describe. It is interpolated
            into the SQL text as-is.

    Returns:
        The first row of the validated result, converted with ``as_dict()``.
    """
    checker = query_result_checker.SqlResultValidator(
        session=session,
        query=f"DESC COMPUTE POOL {compute_pool_name}",
    )
    # The result must expose the instance-family and name columns and hold exactly one row.
    checker = checker.has_column(_INSTANCE_FAMILY)
    checker = checker.has_column(_NAME)
    checker = checker.has_dimensions(expected_rows=1)
    rows = checker.validate()
    return rows[0].as_dict()
35
+
36
+
37
def _desc_service(session: snowpark.Session, fully_qualified_name: str) -> Dict[str, Any]:
    """Run ``DESC SERVICE`` and return the resulting row as a dictionary.

    Args:
        session: Snowpark session the query is issued through.
        fully_qualified_name: Name of the service to describe. It is interpolated
            into the SQL text as-is.

    Returns:
        The first row of the validated result, converted with ``as_dict()``.
    """
    checker = query_result_checker.SqlResultValidator(
        session=session,
        query=f"DESC SERVICE {fully_qualified_name}",
    )
    # The result must expose the compute-pool column and hold exactly one row.
    checker = checker.has_column(_COMPUTE_POOL)
    checker = checker.has_dimensions(expected_rows=1)
    rows = checker.validate()
    return rows[0].as_dict()
49
+
50
+
51
+ def _get_current_time() -> datetime:
52
+ """
53
+ This method exists to make it easier to mock datetime in test.
54
+
55
+ Returns:
56
+ current datetime
57
+ """
58
+ return datetime.now()
59
+
60
+
61
def _send_service_telemetry(
    fully_qualified_name: Optional[str] = None,
    compute_pool_name: Optional[str] = None,
    service_details: Optional[Dict[str, Any]] = None,
    compute_pool_details: Optional[Dict[str, Any]] = None,
    duration_in_seconds: Optional[int] = None,
    kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Send one SPCS usage telemetry event, logging instead of raising on failure.

    Args:
        fully_qualified_name: Service name, reported as ``service_name``.
        compute_pool_name: Name of the compute pool the service uses.
        service_details: Row describing the service.
        compute_pool_details: Row describing the compute pool.
        duration_in_seconds: Service lifetime in seconds, when known.
        kwargs: Extra fields forwarded unchanged to ``telemetry.send_custom_usage``.
    """
    payload = {
        "service_name": fully_qualified_name,
        "compute_pool_name": compute_pool_name,
        "service_details": service_details,
        "compute_pool_details": compute_pool_details,
        "duration_in_seconds": duration_in_seconds,
    }
    try:
        telemetry.send_custom_usage(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
            telemetry_type=telemetry.TelemetryField.TYPE_SNOWML_SPCS_USAGE.value,
            data=payload,
            kwargs=kwargs,
        )
    except Exception as e:
        # Telemetry is best-effort: a failure here is logged, never propagated to the caller.
        logger.error(f"Failed to send service telemetry: {e}")
85
+
86
+
87
def record_service_start(session: snowpark.Session, fully_qualified_name: str) -> None:
    """Describe a service and its compute pool, then emit a service-start telemetry event.

    Args:
        session: Snowpark session used for the DESC queries.
        fully_qualified_name: Name of the service that was started.
    """
    service_details = _desc_service(session, fully_qualified_name)
    compute_pool_name = service_details[_COMPUTE_POOL]

    event_fields = dict(
        fully_qualified_name=fully_qualified_name,
        compute_pool_name=compute_pool_name,
        service_details=service_details,
        compute_pool_details=_desc_compute_pool(session, compute_pool_name),
        kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_START},
    )
    _send_service_telemetry(**event_fields)

    logger.info(f"Service {fully_qualified_name} created with compute pool {compute_pool_name}.")
101
+
102
+
103
def record_service_end(session: snowpark.Session, fully_qualified_name: str) -> None:
    """Describe a service and its compute pool, then emit a service-end telemetry event.

    The event carries the service lifetime in whole seconds, measured from the
    ``created_on`` value of the DESC SERVICE row to the current time.

    Args:
        session: Snowpark session used for the DESC queries.
        fully_qualified_name: Name of the service that is being deleted.
    """
    service_details = _desc_service(session, fully_qualified_name)
    compute_pool_name = service_details[_COMPUTE_POOL]
    compute_pool_details = _desc_compute_pool(session, compute_pool_name)

    created_on_datetime: datetime = service_details[_CREATED_ON]
    current_time: datetime = _get_current_time()
    created_on_tz = created_on_datetime.tzinfo
    if created_on_tz is None:
        # Naive creation timestamp: compare wall-clock to wall-clock, as before.
        current_time = current_time.replace(tzinfo=None)
    else:
        # Convert rather than relabel. Stamping the client's local wall time with the
        # service's tzinfo shifts the instant by the offset between the two zones and
        # skews the duration. astimezone() treats a naive value as system local time.
        current_time = current_time.astimezone(created_on_tz)
    duration_in_seconds = int((current_time - created_on_datetime).total_seconds())

    _send_service_telemetry(
        fully_qualified_name=fully_qualified_name,
        compute_pool_name=compute_pool_name,
        service_details=service_details,
        compute_pool_details=compute_pool_details,
        duration_in_seconds=duration_in_seconds,
        kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_END},
    )

    logger.info(f"Service {fully_qualified_name} deleted from compute pool {compute_pool_name}")
@@ -140,7 +140,7 @@ Got {len(self.df.queries['queries'])}: {self.df.queries['queries']}
140
140
 
141
141
  @classmethod
142
142
  def from_json(cls, json_str: str, session: Session) -> "Dataset":
143
- json_dict = json.loads(json_str)
143
+ json_dict = json.loads(json_str, strict=False)
144
144
  json_dict["df"] = session.sql(json_dict.pop("df_query"))
145
145
 
146
146
  fs_meta_json = json_dict["feature_store_metadata"]