snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -23,7 +23,9 @@ from typing import Dict, List, Optional, Tuple
23
23
 
24
24
  import requests
25
25
 
26
- from snowflake.ml._internal.utils import image_registry_http_client
26
+ from snowflake.ml._internal.container_services.image_registry import (
27
+ http_client as image_registry_http_client,
28
+ )
27
29
 
28
30
  # Common HTTP headers
29
31
  _CONTENT_LENGTH_HEADER = "content-length"
@@ -3,12 +3,14 @@ import logging
3
3
  from typing import Dict, Optional, cast
4
4
  from urllib.parse import urlunparse
5
5
 
6
+ from snowflake.ml._internal.container_services.image_registry import (
7
+ http_client as image_registry_http_client,
8
+ imagelib,
9
+ )
6
10
  from snowflake.ml._internal.exceptions import (
7
11
  error_codes,
8
12
  exceptions as snowml_exceptions,
9
13
  )
10
- from snowflake.ml._internal.utils import image_registry_http_client
11
- from snowflake.ml.model._deploy_client.utils import imagelib
12
14
  from snowflake.snowpark import Session
13
15
  from snowflake.snowpark._internal import utils as snowpark_utils
14
16
 
@@ -33,7 +33,6 @@ class CONDA_OS(Enum):
33
33
 
34
34
  _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
35
35
  _NODEFAULTS = "nodefaults"
36
- _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION: Optional[bool] = None
37
36
  _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
38
37
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
39
38
 
@@ -267,18 +266,6 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
267
266
  return new_req
268
267
 
269
268
 
270
- def _check_runtime_version_column_existence(session: session.Session) -> bool:
271
- sql = textwrap.dedent(
272
- """
273
- SHOW COLUMNS
274
- LIKE 'runtime_version'
275
- IN TABLE information_schema.packages;
276
- """
277
- )
278
- result = session.sql(sql).count()
279
- return result == 1
280
-
281
-
282
269
  def get_matched_package_versions_in_snowflake_conda_channel(
283
270
  req: requirements.Requirement,
284
271
  python_version: str = snowml_env.PYTHON_VERSION,
@@ -325,9 +312,9 @@ def get_matched_package_versions_in_snowflake_conda_channel(
325
312
  return matched_versions
326
313
 
327
314
 
328
- def validate_requirements_in_information_schema(
315
+ def get_matched_package_versions_in_information_schema(
329
316
  session: session.Session, reqs: List[requirements.Requirement], python_version: str
330
- ) -> Optional[List[str]]:
317
+ ) -> Dict[str, List[version.Version]]:
331
318
  """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
332
319
  Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
333
320
  exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -338,42 +325,35 @@ def validate_requirements_in_information_schema(
338
325
  python_version: A string of python version where model is run.
339
326
 
340
327
  Returns:
341
- A list of pinned latest version that available in Snowflake anaconda channel and meet the version specifier.
328
+ A Dict, whose key is the package name, and value is a list of versions match the requirements.
342
329
  """
343
- global _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION
344
-
345
- if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION is None:
346
- _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION = _check_runtime_version_column_existence(session)
347
- ret_list = []
348
- reqs_to_request = []
330
+ ret_dict: Dict[str, List[version.Version]] = {}
331
+ reqs_to_request: List[requirements.Requirement] = []
349
332
  for req in reqs:
350
- if req.name not in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
333
+ if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
334
+ available_versions = list(
335
+ sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
336
+ )
337
+ ret_dict[req.name] = available_versions
338
+ else:
351
339
  reqs_to_request.append(req)
340
+
352
341
  if reqs_to_request:
353
342
  pkg_names_str = " OR ".join(
354
343
  f"package_name = '{req_name}'" for req_name in sorted(req.name for req in reqs_to_request)
355
344
  )
356
- if _INFO_SCHEMA_PACKAGES_HAS_RUNTIME_VERSION:
357
- parsed_python_version = version.Version(python_version)
358
- sql = textwrap.dedent(
359
- f"""
360
- SELECT PACKAGE_NAME, VERSION
361
- FROM information_schema.packages
362
- WHERE ({pkg_names_str})
363
- AND language = 'python'
364
- AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'
365
- OR runtime_version is null);
366
- """
367
- )
368
- else:
369
- sql = textwrap.dedent(
370
- f"""
371
- SELECT PACKAGE_NAME, VERSION
372
- FROM information_schema.packages
373
- WHERE ({pkg_names_str})
374
- AND language = 'python';
375
- """
376
- )
345
+
346
+ parsed_python_version = version.Version(python_version)
347
+ sql = textwrap.dedent(
348
+ f"""
349
+ SELECT PACKAGE_NAME, VERSION
350
+ FROM information_schema.packages
351
+ WHERE ({pkg_names_str})
352
+ AND language = 'python'
353
+ AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'
354
+ OR runtime_version is null);
355
+ """
356
+ )
377
357
 
378
358
  try:
379
359
  result = (
@@ -392,14 +372,13 @@ def validate_requirements_in_information_schema(
392
372
  cached_req_ver_list.append(req_ver)
393
373
  _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE[req_name] = cached_req_ver_list
394
374
  except snowflake.connector.DataError:
395
- return None
396
- for req in reqs:
397
- available_versions = list(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
398
- if not available_versions:
399
- return None
400
- else:
401
- ret_list.append(str(req))
402
- return sorted(ret_list)
375
+ return ret_dict
376
+ for req in reqs_to_request:
377
+ available_versions = list(
378
+ sorted(req.specifier.filter(set(_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE.get(req.name, []))))
379
+ )
380
+ ret_dict[req.name] = available_versions
381
+ return ret_dict
403
382
 
404
383
 
405
384
  def save_conda_env_file(
@@ -362,3 +362,20 @@ def download_directory_from_stage(
362
362
  wait_exponential_multiplier=100,
363
363
  wait_exponential_max=10000,
364
364
  )(file_operation.get)(str(stage_file_path), str(local_file_dir), statement_params=statement_params)
365
+
366
+
367
+ def open_file(path: str, *args: Any, **kwargs: Any) -> Any:
368
+ """This function is a wrapper on top of the Python built-in "open" function, with a few added default values
369
+ to ensure successful execution across different platforms.
370
+
371
+ Args:
372
+ path: file path
373
+ *args: arguments.
374
+ **kwargs: key arguments.
375
+
376
+ Returns:
377
+ Open file and return a stream.
378
+ """
379
+ kwargs.setdefault("newline", "\n")
380
+ kwargs.setdefault("encoding", "utf-8")
381
+ return open(path, *args, **kwargs)
@@ -584,3 +584,22 @@ class _SourceTelemetryClient:
584
584
  """Send the telemetry data batch immediately."""
585
585
  if self._telemetry:
586
586
  self._telemetry.send_batch()
587
+
588
+
589
+ def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: Dict[str, Any]) -> Dict[str, Any]:
590
+ """
591
+ Get statement_params keyword argument for sproc call.
592
+
593
+ Args:
594
+ sproc: sproc function
595
+ statement_params: dictionary to be passed as statement params, if possible
596
+
597
+ Returns:
598
+ Keyword arguments dict
599
+ """
600
+ sproc_argspec = inspect.getfullargspec(sproc)
601
+ kwargs = {}
602
+ if "statement_params" in sproc_argspec.args:
603
+ kwargs["statement_params"] = statement_params
604
+
605
+ return kwargs
@@ -60,9 +60,13 @@ def result_dimension_matcher(
60
60
  return True
61
61
 
62
62
 
63
- def column_name_matcher(expected_col_name: str, result: list[snowpark.Row], sql: str | None = None) -> bool:
63
+ def column_name_matcher(
64
+ expected_col_name: str, allow_empty: bool, result: list[snowpark.Row], sql: str | None = None
65
+ ) -> bool:
64
66
  """Returns true if `expected_col_name` is found. Raise exception otherwise."""
65
67
  if not result:
68
+ if allow_empty:
69
+ return True
66
70
  raise connector.DataError(f"Query Result is empty.{_query_log(sql)}")
67
71
  if expected_col_name not in result[0]:
68
72
  raise connector.DataError(
@@ -159,16 +163,17 @@ class ResultValidator:
159
163
  self._success_matchers.append(partial(result_dimension_matcher, expected_rows, expected_cols))
160
164
  return self
161
165
 
162
- def has_column(self, expected_col_name: str) -> ResultValidator:
166
+ def has_column(self, expected_col_name: str, allow_empty: bool = False) -> ResultValidator:
163
167
  """Validate that the a column with the name `expected_column_name` exists in the result.
164
168
 
165
169
  Args:
166
170
  expected_col_name: Name of the column that is expected to be present in the result (case sensitive).
171
+ allow_empty: If the check will fail if the result is empty.
167
172
 
168
173
  Returns:
169
174
  ResultValidator object (self)
170
175
  """
171
- self._success_matchers.append(partial(column_name_matcher, expected_col_name))
176
+ self._success_matchers.append(partial(column_name_matcher, expected_col_name, allow_empty))
172
177
  return self
173
178
 
174
179
  def has_named_value_match(self, row_idx: int, col_name: str, expected_value: Any) -> ResultValidator:
@@ -224,8 +229,6 @@ class ResultValidator:
224
229
  Returns:
225
230
  Query result.
226
231
  """
227
- if len(self._success_matchers) == 0:
228
- self._success_matchers = _DEFAULT_MATCHERS
229
232
  result = self._get_result()
230
233
  for matcher in self._success_matchers:
231
234
  assert matcher(result, self._query)
@@ -0,0 +1,95 @@
1
+ import enum
2
+ from typing import Any, Dict, Optional, TypedDict, cast
3
+
4
+ from packaging import version
5
+ from typing_extensions import Required
6
+
7
+ from snowflake.ml._internal.utils import query_result_checker
8
+ from snowflake.snowpark import session
9
+
10
+
11
+ def get_current_snowflake_version(
12
+ sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
13
+ ) -> version.Version:
14
+ """Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
15
+ "7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
16
+
17
+ Args:
18
+ sess: Snowpark Session.
19
+ statement_params: Statement params. Defaults to None.
20
+
21
+ Returns:
22
+ The version of Snowflake Version.
23
+ """
24
+ res = (
25
+ query_result_checker.SqlResultValidator(
26
+ sess, "SELECT CURRENT_VERSION() AS CURRENT_VERSION", statement_params=statement_params
27
+ )
28
+ .has_dimensions(expected_rows=1, expected_cols=1)
29
+ .validate()[0]
30
+ )
31
+
32
+ version_str = res.CURRENT_VERSION
33
+ assert isinstance(version_str, str)
34
+
35
+ version_str = "+".join(version_str.split())
36
+ return version.parse(version_str)
37
+
38
+
39
+ class SnowflakeCloudType(enum.Enum):
40
+ AWS = "aws"
41
+ AZURE = "azure"
42
+ GCP = "gcp"
43
+
44
+ @classmethod
45
+ def from_value(cls, value: str) -> "SnowflakeCloudType":
46
+ assert value
47
+ for k in cls:
48
+ if k.value == value.lower():
49
+ return k
50
+ else:
51
+ raise ValueError(f"'{cls.__name__}' enum not found for '{value}'")
52
+
53
+
54
+ class SnowflakeRegion(TypedDict):
55
+ region_group: Required[str]
56
+ snowflake_region: Required[str]
57
+ cloud: Required[SnowflakeCloudType]
58
+ region: Required[str]
59
+ display_name: Required[str]
60
+
61
+
62
+ def get_regions(
63
+ sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
64
+ ) -> Dict[str, SnowflakeRegion]:
65
+ res = (
66
+ query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
67
+ .has_column("region_group")
68
+ .has_column("snowflake_region")
69
+ .has_column("cloud")
70
+ .has_column("region")
71
+ .has_column("display_name")
72
+ .validate()
73
+ )
74
+ return {
75
+ f"{r.region_group}.{r.snowflake_region}": SnowflakeRegion(
76
+ region_group=r.region_group,
77
+ snowflake_region=r.snowflake_region,
78
+ cloud=SnowflakeCloudType.from_value(r.cloud),
79
+ region=r.region,
80
+ display_name=r.display_name,
81
+ )
82
+ for r in res
83
+ }
84
+
85
+
86
+ def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
87
+ res = (
88
+ query_result_checker.SqlResultValidator(
89
+ sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
90
+ )
91
+ .has_dimensions(expected_rows=1, expected_cols=1)
92
+ .validate()[0]
93
+ )
94
+
95
+ return cast(str, res.CURRENT_REGION)
@@ -1,4 +1,6 @@
1
1
  import collections
2
+ import logging
3
+ import time
2
4
  from typing import Any, Deque, Dict, Iterator, List
3
5
 
4
6
  import fsspec
@@ -83,7 +85,7 @@ class ParquetParser:
83
85
  np.random.shuffle(files)
84
86
  pa_dataset: ds.Dataset = ds.dataset(files, format="parquet", filesystem=self._fs)
85
87
 
86
- for rb in pa_dataset.to_batches(batch_size=self._dataset_batch_size):
88
+ for rb in _retryable_batches(pa_dataset, batch_size=self._dataset_batch_size):
87
89
  if self._shuffle:
88
90
  rb = rb.take(np.random.permutation(rb.num_rows))
89
91
  self._rb_buffer.append(rb)
@@ -138,3 +140,31 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
138
140
  array = column.to_numpy(zero_copy_only=False)
139
141
  batch_dict[column_schema.name] = array
140
142
  return batch_dict
143
+
144
+
145
+ def _retryable_batches(
146
+ dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
147
+ ) -> Iterator[pa.RecordBatch]:
148
+ """Make the Dataset to_batches retryable."""
149
+ retries = 0
150
+ current_batch_index = 0
151
+
152
+ while True:
153
+ try:
154
+ for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
155
+ if batch_index < current_batch_index:
156
+ # Skip batches that have already been processed
157
+ continue
158
+
159
+ yield batch
160
+ current_batch_index = batch_index + 1
161
+ # Exit the loop once all batches are processed
162
+ break
163
+
164
+ except Exception as e:
165
+ if retries < max_retries:
166
+ retries += 1
167
+ logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
168
+ time.sleep(delay)
169
+ else:
170
+ raise e
@@ -0,0 +1,6 @@
1
+ from snowflake.ml.model._client.model.model_impl import Model
2
+ from snowflake.ml.model._client.model.model_version_impl import ModelVersion
3
+ from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
4
+ from snowflake.ml.model.models.llm import LLM, LLMOptions
5
+
6
+ __all__ = ["Model", "ModelVersion", "HuggingFacePipelineModel", "LLM", "LLMOptions"]
@@ -1,7 +1,9 @@
1
- from typing import List, Union
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import pandas as pd
2
4
 
3
5
  from snowflake.ml._internal import telemetry
4
- from snowflake.ml._internal.utils import sql_identifier
6
+ from snowflake.ml._internal.utils import identifier, sql_identifier
5
7
  from snowflake.ml.model._client.model import model_version_impl
6
8
  from snowflake.ml.model._client.ops import model_ops
7
9
 
@@ -37,10 +39,12 @@ class Model:
37
39
 
38
40
  @property
39
41
  def name(self) -> str:
42
+ """Return the name of the model that can be used to refer to it in SQL."""
40
43
  return self._model_name.identifier()
41
44
 
42
45
  @property
43
46
  def fully_qualified_name(self) -> str:
47
+ """Return the fully qualified name of the model that can be used to refer to it in SQL."""
44
48
  return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
45
49
 
46
50
  @property
@@ -49,6 +53,24 @@ class Model:
49
53
  subproject=_TELEMETRY_SUBPROJECT,
50
54
  )
51
55
  def description(self) -> str:
56
+ """The description for the model. This is an alias of `comment`."""
57
+ return self.comment
58
+
59
+ @description.setter
60
+ @telemetry.send_api_usage_telemetry(
61
+ project=_TELEMETRY_PROJECT,
62
+ subproject=_TELEMETRY_SUBPROJECT,
63
+ )
64
+ def description(self, description: str) -> None:
65
+ self.comment = description
66
+
67
+ @property
68
+ @telemetry.send_api_usage_telemetry(
69
+ project=_TELEMETRY_PROJECT,
70
+ subproject=_TELEMETRY_SUBPROJECT,
71
+ )
72
+ def comment(self) -> str:
73
+ """The comment to the model."""
52
74
  statement_params = telemetry.get_statement_params(
53
75
  project=_TELEMETRY_PROJECT,
54
76
  subproject=_TELEMETRY_SUBPROJECT,
@@ -58,18 +80,18 @@ class Model:
58
80
  statement_params=statement_params,
59
81
  )
60
82
 
61
- @description.setter
83
+ @comment.setter
62
84
  @telemetry.send_api_usage_telemetry(
63
85
  project=_TELEMETRY_PROJECT,
64
86
  subproject=_TELEMETRY_SUBPROJECT,
65
87
  )
66
- def description(self, description: str) -> None:
88
+ def comment(self, comment: str) -> None:
67
89
  statement_params = telemetry.get_statement_params(
68
90
  project=_TELEMETRY_PROJECT,
69
91
  subproject=_TELEMETRY_SUBPROJECT,
70
92
  )
71
93
  return self._model_ops.set_comment(
72
- comment=description,
94
+ comment=comment,
73
95
  model_name=self._model_name,
74
96
  statement_params=statement_params,
75
97
  )
@@ -80,12 +102,13 @@ class Model:
80
102
  subproject=_TELEMETRY_SUBPROJECT,
81
103
  )
82
104
  def default(self) -> model_version_impl.ModelVersion:
105
+ """The default version of the model."""
83
106
  statement_params = telemetry.get_statement_params(
84
107
  project=_TELEMETRY_PROJECT,
85
108
  subproject=_TELEMETRY_SUBPROJECT,
86
109
  class_name=self.__class__.__name__,
87
110
  )
88
- default_version_name = self._model_ops._model_version_client.get_default_version(
111
+ default_version_name = self._model_ops.get_default_version(
89
112
  model_name=self._model_name, statement_params=statement_params
90
113
  )
91
114
  return self.version(default_version_name)
@@ -105,7 +128,7 @@ class Model:
105
128
  version_name = sql_identifier.SqlIdentifier(version)
106
129
  else:
107
130
  version_name = version._version_name
108
- self._model_ops._model_version_client.set_default_version(
131
+ self._model_ops.set_default_version(
109
132
  model_name=self._model_name, version_name=version_name, statement_params=statement_params
110
133
  )
111
134
 
@@ -114,13 +137,14 @@ class Model:
114
137
  subproject=_TELEMETRY_SUBPROJECT,
115
138
  )
116
139
  def version(self, version_name: str) -> model_version_impl.ModelVersion:
117
- """Get a model version object given a version name in the model.
140
+ """
141
+ Get a model version object given a version name in the model.
118
142
 
119
143
  Args:
120
- version_name: The name of version
144
+ version_name: The name of the version.
121
145
 
122
146
  Raises:
123
- ValueError: Raised when the version requested does not exist.
147
+ ValueError: When the requested version does not exist.
124
148
 
125
149
  Returns:
126
150
  The model version object.
@@ -149,11 +173,11 @@ class Model:
149
173
  project=_TELEMETRY_PROJECT,
150
174
  subproject=_TELEMETRY_SUBPROJECT,
151
175
  )
152
- def list_versions(self) -> List[model_version_impl.ModelVersion]:
153
- """List all versions in the model.
176
+ def versions(self) -> List[model_version_impl.ModelVersion]:
177
+ """Get all versions in the model.
154
178
 
155
179
  Returns:
156
- A List of ModelVersion object representing all versions in the model.
180
+ A list of ModelVersion objects representing all versions in the model.
157
181
  """
158
182
  statement_params = telemetry.get_statement_params(
159
183
  project=_TELEMETRY_PROJECT,
@@ -172,5 +196,140 @@ class Model:
172
196
  for version_name in version_names
173
197
  ]
174
198
 
199
+ @telemetry.send_api_usage_telemetry(
200
+ project=_TELEMETRY_PROJECT,
201
+ subproject=_TELEMETRY_SUBPROJECT,
202
+ )
203
+ def show_versions(self) -> pd.DataFrame:
204
+ """Show information about all versions in the model.
205
+
206
+ Returns:
207
+ A Pandas DataFrame showing information about all versions in the model.
208
+ """
209
+ statement_params = telemetry.get_statement_params(
210
+ project=_TELEMETRY_PROJECT,
211
+ subproject=_TELEMETRY_SUBPROJECT,
212
+ )
213
+ rows = self._model_ops.show_models_or_versions(
214
+ model_name=self._model_name,
215
+ statement_params=statement_params,
216
+ )
217
+ return pd.DataFrame([row.as_dict() for row in rows])
218
+
175
219
  def delete_version(self, version_name: str) -> None:
176
220
  raise NotImplementedError("Deleting version has not been supported yet.")
221
+
222
+ @telemetry.send_api_usage_telemetry(
223
+ project=_TELEMETRY_PROJECT,
224
+ subproject=_TELEMETRY_SUBPROJECT,
225
+ )
226
+ def show_tags(self) -> Dict[str, str]:
227
+ """Get a dictionary showing the tags attached to the model and their values.
228
+
229
+ Returns:
230
+ A dictionary mapping the name of each tag attached to the model to its value.
231
+ """
232
+ statement_params = telemetry.get_statement_params(
233
+ project=_TELEMETRY_PROJECT,
234
+ subproject=_TELEMETRY_SUBPROJECT,
235
+ )
236
+ return self._model_ops.show_tags(model_name=self._model_name, statement_params=statement_params)
237
+
238
+ def _parse_tag_name(
239
+ self,
240
+ tag_name: str,
241
+ ) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
242
+ _tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
243
+ if _tag_db is None:
244
+ tag_db_id = self._model_ops._model_client._database_name
245
+ else:
246
+ tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
247
+
248
+ if _tag_schema is None:
249
+ tag_schema_id = self._model_ops._model_client._schema_name
250
+ else:
251
+ tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
252
+
253
+ if _tag_name is None:
254
+ raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
255
+
256
+ tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
257
+
258
+ return tag_db_id, tag_schema_id, tag_name_id
259
+
260
+ @telemetry.send_api_usage_telemetry(
261
+ project=_TELEMETRY_PROJECT,
262
+ subproject=_TELEMETRY_SUBPROJECT,
263
+ )
264
+ def get_tag(self, tag_name: str) -> Optional[str]:
265
+ """Get the value of a tag attached to the model.
266
+
267
+ Args:
268
+ tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
269
+ the model will be used.
270
+
271
+ Returns:
272
+ The tag value as a string if the tag is attached, otherwise None.
273
+ """
274
+ statement_params = telemetry.get_statement_params(
275
+ project=_TELEMETRY_PROJECT,
276
+ subproject=_TELEMETRY_SUBPROJECT,
277
+ )
278
+ tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
279
+ return self._model_ops.get_tag_value(
280
+ model_name=self._model_name,
281
+ tag_database_name=tag_db_id,
282
+ tag_schema_name=tag_schema_id,
283
+ tag_name=tag_name_id,
284
+ statement_params=statement_params,
285
+ )
286
+
287
+ @telemetry.send_api_usage_telemetry(
288
+ project=_TELEMETRY_PROJECT,
289
+ subproject=_TELEMETRY_SUBPROJECT,
290
+ )
291
+ def set_tag(self, tag_name: str, tag_value: str) -> None:
292
+ """Set the value of a tag, attaching it to the model if it is not already attached.
293
+
294
+ Args:
295
+ tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
296
+ the model will be used.
297
+ tag_value: The value of the tag.
298
+ """
299
+ statement_params = telemetry.get_statement_params(
300
+ project=_TELEMETRY_PROJECT,
301
+ subproject=_TELEMETRY_SUBPROJECT,
302
+ )
303
+ tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
304
+ self._model_ops.set_tag(
305
+ model_name=self._model_name,
306
+ tag_database_name=tag_db_id,
307
+ tag_schema_name=tag_schema_id,
308
+ tag_name=tag_name_id,
309
+ tag_value=tag_value,
310
+ statement_params=statement_params,
311
+ )
312
+
313
+ @telemetry.send_api_usage_telemetry(
314
+ project=_TELEMETRY_PROJECT,
315
+ subproject=_TELEMETRY_SUBPROJECT,
316
+ )
317
+ def unset_tag(self, tag_name: str) -> None:
318
+ """Unset a tag attached to the model.
319
+
320
+ Args:
321
+ tag_name: The name of the tag, can be fully qualified. If not fully qualified, the database or schema of
322
+ the model will be used.
323
+ """
324
+ statement_params = telemetry.get_statement_params(
325
+ project=_TELEMETRY_PROJECT,
326
+ subproject=_TELEMETRY_SUBPROJECT,
327
+ )
328
+ tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
329
+ self._model_ops.unset_tag(
330
+ model_name=self._model_name,
331
+ tag_database_name=tag_db_id,
332
+ tag_schema_name=tag_schema_id,
333
+ tag_name=tag_name_id,
334
+ statement_params=statement_params,
335
+ )