snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/identifier.py +78 -72
  10. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  11. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  12. snowflake/ml/dataset/dataset.py +1 -1
  13. snowflake/ml/model/_api.py +21 -14
  14. snowflake/ml/model/_client/model/model_impl.py +176 -0
  15. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  17. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  18. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  19. snowflake/ml/model/_client/sql/model.py +75 -0
  20. snowflake/ml/model/_client/sql/model_version.py +213 -0
  21. snowflake/ml/model/_client/sql/stage.py +40 -0
  22. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  23. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  24. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  25. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  26. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  27. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  28. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  30. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  31. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  33. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  35. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  36. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  37. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  38. snowflake/ml/model/model_signature.py +108 -53
  39. snowflake/ml/model/type_hints.py +1 -0
  40. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  41. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  42. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  43. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  44. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  45. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  46. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  47. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
  51. snowflake/ml/modeling/cluster/birch.py +106 -135
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
  53. snowflake/ml/modeling/cluster/dbscan.py +106 -135
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
  55. snowflake/ml/modeling/cluster/k_means.py +105 -135
  56. snowflake/ml/modeling/cluster/mean_shift.py +106 -135
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
  58. snowflake/ml/modeling/cluster/optics.py +106 -135
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
  62. snowflake/ml/modeling/compose/column_transformer.py +106 -135
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
  69. snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
  70. snowflake/ml/modeling/covariance/oas.py +99 -128
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
  74. snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
  79. snowflake/ml/modeling/decomposition/pca.py +106 -135
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
  108. snowflake/ml/modeling/framework/base.py +83 -1
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
  111. snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
  112. snowflake/ml/modeling/impute/knn_imputer.py +106 -135
  113. snowflake/ml/modeling/impute/missing_indicator.py +106 -135
  114. snowflake/ml/modeling/impute/simple_imputer.py +9 -1
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
  123. snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
  125. snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
  129. snowflake/ml/modeling/linear_model/lars.py +108 -135
  130. snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
  131. snowflake/ml/modeling/linear_model/lasso.py +108 -135
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
  136. snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
  146. snowflake/ml/modeling/linear_model/perceptron.py +107 -135
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
  149. snowflake/ml/modeling/linear_model/ridge.py +108 -135
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
  158. snowflake/ml/modeling/manifold/isomap.py +106 -135
  159. snowflake/ml/modeling/manifold/mds.py +106 -135
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
  161. snowflake/ml/modeling/manifold/tsne.py +106 -135
  162. snowflake/ml/modeling/metrics/classification.py +196 -55
  163. snowflake/ml/modeling/metrics/correlation.py +4 -2
  164. snowflake/ml/modeling/metrics/covariance.py +7 -4
  165. snowflake/ml/modeling/metrics/ranking.py +32 -16
  166. snowflake/ml/modeling/metrics/regression.py +60 -32
  167. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
  168. snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
  169. snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
  170. snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
  181. snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
  191. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  192. snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
  193. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
  194. snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
  195. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
  196. snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
  197. snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
  198. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
  199. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
  200. snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
  201. snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
  202. snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
  203. snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
  204. snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
  205. snowflake/ml/modeling/svm/linear_svc.py +108 -135
  206. snowflake/ml/modeling/svm/linear_svr.py +108 -135
  207. snowflake/ml/modeling/svm/nu_svc.py +108 -135
  208. snowflake/ml/modeling/svm/nu_svr.py +108 -135
  209. snowflake/ml/modeling/svm/svc.py +108 -135
  210. snowflake/ml/modeling/svm/svr.py +108 -135
  211. snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
  212. snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
  213. snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
  214. snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
  215. snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
  216. snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
  217. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
  218. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
  219. snowflake/ml/registry/model_registry.py +2 -0
  220. snowflake/ml/registry/registry.py +215 -0
  221. snowflake/ml/version.py +1 -1
  222. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
  223. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  224. snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
  225. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -23,7 +23,7 @@ def Complete(
23
23
  A column of string responses.
24
24
  """
25
25
 
26
- return _complete_impl("snowflake.ml.complete", model, prompt, session=session)
26
+ return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
27
27
 
28
28
 
29
29
  def _complete_impl(
@@ -25,7 +25,7 @@ def ExtractAnswer(
25
25
  A column of strings containing answers.
26
26
  """
27
27
 
28
- return _extract_answer_impl("snowflake.ml.extract_answer", from_text, question, session=session)
28
+ return _extract_answer_impl("snowflake.cortex.extract_answer", from_text, question, session=session)
29
29
 
30
30
 
31
31
  def _extract_answer_impl(
@@ -22,7 +22,7 @@ def Sentiment(
22
22
  A column of floats. 1 represents positive sentiment, -1 represents negative sentiment.
23
23
  """
24
24
 
25
- return _sentiment_impl("snowflake.ml.sentiment", text, session=session)
25
+ return _sentiment_impl("snowflake.cortex.sentiment", text, session=session)
26
26
 
27
27
 
28
28
  def _sentiment_impl(
@@ -23,7 +23,7 @@ def Summarize(
23
23
  A column of string summaries.
24
24
  """
25
25
 
26
- return _summarize_impl("snowflake.ml.summarize", text, session=session)
26
+ return _summarize_impl("snowflake.cortex.summarize", text, session=session)
27
27
 
28
28
 
29
29
  def _summarize_impl(
@@ -27,7 +27,7 @@ def Translate(
27
27
  A column of string translations.
28
28
  """
29
29
 
30
- return _translate_impl("snowflake.ml.translate", text, from_language, to_language, session=session)
30
+ return _translate_impl("snowflake.cortex.translate", text, from_language, to_language, session=session)
31
31
 
32
32
 
33
33
  def _translate_impl(
@@ -4,6 +4,7 @@ import pathlib
4
4
  import re
5
5
  import textwrap
6
6
  import warnings
7
+ from enum import Enum
7
8
  from importlib import metadata as importlib_metadata
8
9
  from typing import Any, DefaultDict, Dict, List, Optional, Tuple
9
10
 
@@ -18,10 +19,22 @@ from snowflake.ml._internal.exceptions import (
18
19
  )
19
20
  from snowflake.ml._internal.utils import query_result_checker
20
21
  from snowflake.snowpark import session
22
+ from snowflake.snowpark._internal import utils as snowpark_utils
23
+
24
+
25
+ class CONDA_OS(Enum):
26
+ LINUX_64 = "linux-64"
27
+ LINUX_AARCH64 = "linux-aarch64"
28
+ OSX_64 = "osx-64"
29
+ OSX_ARM64 = "osx-arm64"
30
+ WIN_64 = "win-64"
31
+ NO_ARCH = "noarch"
32
+
21
33
 
22
34
  _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
23
35
  _NODEFAULTS = "nodefaults"
24
36
  _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None
37
+ _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
25
38
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
26
39
 
27
40
  DEFAULT_CHANNEL_NAME = ""
@@ -217,6 +230,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
217
230
  warnings.warn(
218
231
  f"Package requirement {str(pip_req)} specified, while version {local_dist_version} is installed. "
219
232
  "Local version will be ignored to conform to package requirement.",
233
+ stacklevel=2,
220
234
  category=UserWarning,
221
235
  )
222
236
  return pip_req
@@ -265,10 +279,58 @@ def _check_runtime_version_column_existence(session: session.Session) -> bool:
265
279
  return result == 1
266
280
 
267
281
 
268
- def validate_requirements_in_snowflake_conda_channel(
282
+ def get_matched_package_versions_in_snowflake_conda_channel(
283
+ req: requirements.Requirement,
284
+ python_version: str = snowml_env.PYTHON_VERSION,
285
+ conda_os: CONDA_OS = CONDA_OS.LINUX_64,
286
+ ) -> List[version.Version]:
287
+ """Search the snowflake anaconda channel for packages that matches the specifier. Note that this will be the
288
+ source of truth for checking whether a package indeed exists in Snowflake conda channel.
289
+
290
+ Given that a package comes in different architectures, we only check for the Linux x86_64 architecture and assume
291
+ the package exists in other architectures. If such an assumption does not hold true for a certain package, the
292
+ caller should specify the architecture to search for.
293
+
294
+ Args:
295
+ req: Requirement specifier.
296
+ python_version: A string of python version where model is run.
297
+ conda_os: Specified platform to search availability of the package.
298
+
299
+ Returns:
300
+ List of package versions that meet the requirement specifier.
301
+ """
302
+ # Move the retryable_http import here as when UDF import this file, it won't have the "requests" dependency.
303
+ from snowflake.ml._internal.utils import retryable_http
304
+
305
+ assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
306
+
307
+ url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
308
+
309
+ if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
310
+ http_client = retryable_http.get_http_client()
311
+ parsed_python_version = version.Version(python_version)
312
+ python_version_build_str = f"py{parsed_python_version.major}{parsed_python_version.minor}"
313
+ repodata = http_client.get(url).json()
314
+ assert isinstance(repodata, dict)
315
+ packages_info = repodata["packages"]
316
+ assert isinstance(packages_info, dict)
317
+ version_list = [
318
+ version.parse(package_info["version"])
319
+ for package_info in packages_info.values()
320
+ if package_info["name"] == req.name and python_version_build_str in package_info["build"]
321
+ ]
322
+ _SNOWFLAKE_CONDA_PACKAGE_CACHE[req.name] = version_list
323
+
324
+ matched_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
325
+ return matched_versions
326
+
327
+
328
+ def validate_requirements_in_information_schema(
269
329
  session: session.Session, reqs: List[requirements.Requirement], python_version: str
270
330
  ) -> Optional[List[str]]:
271
- """Search the snowflake anaconda channel for packages with version meet the specifier.
331
+ """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
332
+ Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
333
+ exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
272
334
 
273
335
  Args:
274
336
  session: Snowflake connection session.
@@ -285,7 +347,7 @@ def validate_requirements_in_snowflake_conda_channel(
285
347
  ret_list = []
286
348
  reqs_to_request = []
287
349
  for req in reqs:
288
- if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
350
+ if req.name not in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
289
351
  reqs_to_request.append(req)
290
352
  if reqs_to_request:
291
353
  pkg_names_str = " OR ".join(
@@ -326,13 +388,13 @@ def validate_requirements_in_snowflake_conda_channel(
326
388
  for row in result:
327
389
  req_name = row["PACKAGE_NAME"]
328
390
  req_ver = version.parse(row["VERSION"])
329
- cached_req_ver_list = _SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req_name, [])
391
+ cached_req_ver_list = _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req_name, [])
330
392
  cached_req_ver_list.append(req_ver)
331
- _SNOWFLAKE_CONDA_PACKAGE_CACHE[req_name] = cached_req_ver_list
393
+ _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list
332
394
  except snowflake.connector.DataError:
333
395
  return None
334
396
  for req in reqs:
335
- available_versions = list(req.specifier.filter(set(_SNOWFLAKE_CONDA_PACKAGE_CACHE.get(req.name, []))))
397
+ available_versions = list(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
336
398
  if not available_versions:
337
399
  return None
338
400
  else:
@@ -28,6 +28,7 @@ import cloudpickle
28
28
 
29
29
  from snowflake import snowpark
30
30
  from snowflake.ml._internal.exceptions import exceptions
31
+ from snowflake.snowpark import exceptions as snowpark_exceptions
31
32
 
32
33
  GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")
33
34
 
@@ -286,8 +287,16 @@ def stage_file_exists(
286
287
  return False
287
288
 
288
289
 
290
+ def _retry_on_sql_error(exception: Exception) -> bool:
291
+ return isinstance(exception, snowpark_exceptions.SnowparkSQLException)
292
+
293
+
289
294
  def upload_directory_to_stage(
290
- session: snowpark.Session, local_path: pathlib.Path, stage_path: pathlib.PurePosixPath
295
+ session: snowpark.Session,
296
+ local_path: pathlib.Path,
297
+ stage_path: pathlib.PurePosixPath,
298
+ *,
299
+ statement_params: Optional[Dict[str, Any]] = None,
291
300
  ) -> None:
292
301
  """Upload a local folder recursively to a stage and keep the structure.
293
302
 
@@ -295,7 +304,10 @@ def upload_directory_to_stage(
295
304
  session: Snowpark Session.
296
305
  local_path: Local path to upload.
297
306
  stage_path: Base path in the stage.
307
+ statement_params: Statement Params.
298
308
  """
309
+ import retrying
310
+
299
311
  file_operation = snowpark.FileOperation(session=session)
300
312
 
301
313
  for root, _, filenames in os.walk(local_path):
@@ -305,16 +317,26 @@ def upload_directory_to_stage(
305
317
  stage_dir_path = (
306
318
  stage_path / pathlib.PurePosixPath(local_file_path.relative_to(local_path).as_posix()).parent
307
319
  )
308
- file_operation.put(
320
+ retrying.retry(
321
+ retry_on_exception=_retry_on_sql_error,
322
+ stop_max_attempt_number=5,
323
+ wait_exponential_multiplier=100,
324
+ wait_exponential_max=10000,
325
+ )(file_operation.put)(
309
326
  str(local_file_path),
310
327
  str(stage_dir_path),
311
328
  auto_compress=False,
312
329
  overwrite=False,
330
+ statement_params=statement_params,
313
331
  )
314
332
 
315
333
 
316
334
  def download_directory_from_stage(
317
- session: snowpark.Session, stage_path: pathlib.PurePosixPath, local_path: pathlib.Path
335
+ session: snowpark.Session,
336
+ stage_path: pathlib.PurePosixPath,
337
+ local_path: pathlib.Path,
338
+ *,
339
+ statement_params: Optional[Dict[str, Any]] = None,
318
340
  ) -> None:
319
341
  """Upload a folder in stage recursively to a folder in local and keep the structure.
320
342
 
@@ -322,7 +344,10 @@ def download_directory_from_stage(
322
344
  session: Snowpark Session.
323
345
  stage_path: Stage path to download from.
324
346
  local_path: Local path as the base of destination.
347
+ statement_params: Statement Params.
325
348
  """
349
+ import retrying
350
+
326
351
  file_operation = file_operation = snowpark.FileOperation(session=session)
327
352
  file_list = [
328
353
  pathlib.PurePosixPath(stage_path.parts[0], *pathlib.PurePosixPath(row.name).parts[1:])
@@ -331,4 +356,9 @@ def download_directory_from_stage(
331
356
  for stage_file_path in file_list:
332
357
  local_file_dir = local_path / stage_file_path.relative_to(stage_path).parent
333
358
  local_file_dir.mkdir(parents=True, exist_ok=True)
334
- file_operation.get(str(stage_file_path), str(local_file_dir))
359
+ retrying.retry(
360
+ retry_on_exception=_retry_on_sql_error,
361
+ stop_max_attempt_number=5,
362
+ wait_exponential_multiplier=100,
363
+ wait_exponential_max=10000,
364
+ )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params)
@@ -42,6 +42,7 @@ class TelemetryField(enum.Enum):
42
42
  NAME = "name"
43
43
  # types of telemetry
44
44
  TYPE_FUNCTION_USAGE = "function_usage"
45
+ TYPE_SNOWML_SPCS_USAGE = "snowml_spcs_usage"
45
46
  # message keys for telemetry
46
47
  KEY_PROJECT = "project"
47
48
  KEY_SUBPROJECT = "subproject"
@@ -207,6 +208,23 @@ def suppress_exceptions(func: Callable[..., Any]) -> Callable[..., Any]:
207
208
  return wrapper
208
209
 
209
210
 
211
+ def send_custom_usage(
212
+ project: str,
213
+ *,
214
+ telemetry_type: str,
215
+ subproject: Optional[str] = None,
216
+ data: Optional[Dict[str, Any]] = None,
217
+ **kwargs: Any,
218
+ ) -> None:
219
+ active_session = next(iter(session._get_active_sessions()))
220
+ assert active_session, "Missing active session object"
221
+
222
+ client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject)
223
+ common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
224
+ data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
225
+ client._send(msg=data)
226
+
227
+
210
228
  def send_api_usage_telemetry(
211
229
  project: str,
212
230
  subproject: Optional[str] = None,
@@ -228,7 +246,8 @@ def send_api_usage_telemetry(
228
246
  custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
229
247
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
230
248
  """
231
- Decorator that sends API usage telemetry.
249
+ Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
250
+ the function.
232
251
 
233
252
  Args:
234
253
  project: Project.
@@ -253,6 +272,51 @@ def send_api_usage_telemetry(
253
272
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
254
273
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
255
274
 
275
+ api_calls: List[Union[Dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
276
+ if api_calls_extractor:
277
+ extracted_api_calls = api_calls_extractor(args[0])
278
+ for api_call in extracted_api_calls:
279
+ if isinstance(api_call, str):
280
+ api_calls.append({TelemetryField.NAME.value: api_call})
281
+ elif callable(api_call):
282
+ api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)})
283
+ else:
284
+ api_calls.append(api_call)
285
+ api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)})
286
+
287
+ sfqids = None
288
+ if sfqids_extractor:
289
+ sfqids = sfqids_extractor(args[0])
290
+
291
+ statement_params = get_function_usage_statement_params(
292
+ project=project,
293
+ subproject=subproject,
294
+ function_category=TelemetryField.FUNC_CAT_USAGE.value,
295
+ function_name=_get_full_func_name(func),
296
+ function_parameters=params,
297
+ api_calls=api_calls,
298
+ custom_tags=custom_tags,
299
+ )
300
+
301
+ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[str, Any]) -> _ReturnValue:
302
+ """
303
+ Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame.
304
+ Used to track APIs returning a Snowpark DataFrame.
305
+
306
+ Args:
307
+ obj: Object to check and update.
308
+ statement_params: Statement parameters.
309
+
310
+ Returns:
311
+ Updated object.
312
+ """
313
+ if isinstance(obj, dataframe.DataFrame):
314
+ if hasattr(obj, "_statement_params") and obj._statement_params:
315
+ obj._statement_params.update(statement_params)
316
+ else:
317
+ obj._statement_params = statement_params # type: ignore[assignment]
318
+ return obj
319
+
256
320
  # prioritize `conn_attr_name` over the active session
257
321
  if conn_attr_name:
258
322
  # raise AttributeError if conn attribute does not exist in `self`
@@ -266,36 +330,20 @@ def send_api_usage_telemetry(
266
330
  # server no default session
267
331
  except snowpark_exceptions.SnowparkSessionException:
268
332
  try:
269
- return func(*args, **kwargs)
333
+ return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
270
334
  except Exception as e:
271
335
  if isinstance(e, snowml_exceptions.SnowflakeMLException):
272
- e = e.original_exception
336
+ raise e.original_exception.with_traceback(e.__traceback__) from None
273
337
  # suppress SnowparkSessionException from telemetry in the stack trace
274
338
  raise e from None
275
339
 
276
340
  conn = active_session._conn._conn
277
341
  if (not active_session.telemetry_enabled) or (conn is None):
278
342
  try:
279
- return func(*args, **kwargs)
343
+ return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
280
344
  except snowml_exceptions.SnowflakeMLException as e:
281
345
  raise e.original_exception from e
282
346
 
283
- api_calls: List[Dict[str, Any]] = []
284
- if api_calls_extractor:
285
- extracted_api_calls = api_calls_extractor(args[0])
286
- for api_call in extracted_api_calls:
287
- if isinstance(api_call, str):
288
- api_calls.append({TelemetryField.NAME.value: api_call})
289
- elif callable(api_call):
290
- api_calls.append({TelemetryField.NAME.value: _get_full_func_name(api_call)})
291
- else:
292
- api_calls.append(api_call)
293
- api_calls.append({TelemetryField.NAME.value: _get_full_func_name(func)})
294
-
295
- sfqids = None
296
- if sfqids_extractor:
297
- sfqids = sfqids_extractor(args[0])
298
-
299
347
  # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
300
348
  telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
301
349
  telemetry_args = dict(
@@ -314,22 +362,24 @@ def send_api_usage_telemetry(
314
362
  if hasattr(e, "_snowflake_ml_handled") and e._snowflake_ml_handled:
315
363
  raise e
316
364
  if isinstance(e, snowpark_exceptions.SnowparkClientException):
317
- e = snowml_exceptions.SnowflakeMLException(
365
+ me = snowml_exceptions.SnowflakeMLException(
318
366
  error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e
319
367
  )
320
368
  else:
321
- e = snowml_exceptions.SnowflakeMLException(
369
+ me = snowml_exceptions.SnowflakeMLException(
322
370
  error_code=error_codes.UNDEFINED, original_exception=e
323
371
  )
324
- telemetry_args["error"] = repr(e)
325
- telemetry_args["error_code"] = e.error_code
326
- e.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined]
327
- if e.suppress_source_trace:
328
- raise e.original_exception from None
329
372
  else:
330
- raise e.original_exception from e
373
+ me = e
374
+ telemetry_args["error"] = repr(me)
375
+ telemetry_args["error_code"] = me.error_code
376
+ me.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined]
377
+ if me.suppress_source_trace:
378
+ raise me.original_exception from None
379
+ else:
380
+ raise me.original_exception from e
331
381
  else:
332
- return res
382
+ return update_stmt_params_if_snowpark_df(res, statement_params)
333
383
  finally:
334
384
  telemetry.send_function_usage_telemetry(**telemetry_args)
335
385
  global _log_counter
@@ -343,68 +393,6 @@ def send_api_usage_telemetry(
343
393
  return decorator
344
394
 
345
395
 
346
- def add_stmt_params_to_df(
347
- project: str,
348
- subproject: Optional[str] = None,
349
- *,
350
- function_category: str = TelemetryField.FUNC_CAT_USAGE.value,
351
- func_params_to_log: Optional[Iterable[str]] = None,
352
- api_calls: Optional[
353
- List[
354
- Union[
355
- Dict[str, Union[Callable[..., Any], str]],
356
- Union[Callable[..., Any], str],
357
- ]
358
- ]
359
- ] = None,
360
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
361
- ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
362
- """
363
- Decorator that adds function usage statement parameters to the dataframe returned by the function.
364
-
365
- Args:
366
- project: Project.
367
- subproject: Subproject.
368
- function_category: Function category.
369
- func_params_to_log: Function parameters to log.
370
- api_calls: API calls in the function.
371
- custom_tags: Custom tags.
372
-
373
- Returns:
374
- Decorator that adds function usage statement parameters to the dataframe returned by the decorated function.
375
- """
376
-
377
- def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
378
- @functools.wraps(func)
379
- def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
380
- params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
381
- statement_params = get_function_usage_statement_params(
382
- project=project,
383
- subproject=subproject,
384
- function_category=function_category,
385
- function_name=_get_full_func_name(func),
386
- function_parameters=params,
387
- api_calls=api_calls,
388
- custom_tags=custom_tags,
389
- )
390
-
391
- try:
392
- res = func(*args, **kwargs)
393
- if isinstance(res, dataframe.DataFrame):
394
- if hasattr(res, "_statement_params") and res._statement_params:
395
- res._statement_params.update(statement_params)
396
- else:
397
- res._statement_params = statement_params # type: ignore[assignment]
398
- except Exception:
399
- raise
400
- else:
401
- return res
402
-
403
- return cast(Callable[_Args, _ReturnValue], wrap)
404
-
405
- return decorator
406
-
407
-
408
396
  def _get_full_func_name(func: Callable[..., Any]) -> str:
409
397
  """
410
398
  Get the full function name with module and qualname.