snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (197)
  1. snowflake/cortex/_sentiment.py +7 -4
  2. snowflake/ml/_internal/env_utils.py +6 -0
  3. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  4. snowflake/ml/_internal/telemetry.py +1 -0
  5. snowflake/ml/_internal/utils/identifier.py +1 -1
  6. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  7. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  8. snowflake/ml/dataset/__init__.py +2 -1
  9. snowflake/ml/dataset/dataset.py +4 -3
  10. snowflake/ml/dataset/dataset_reader.py +5 -8
  11. snowflake/ml/feature_store/__init__.py +6 -0
  12. snowflake/ml/feature_store/access_manager.py +283 -0
  13. snowflake/ml/feature_store/feature_store.py +160 -100
  14. snowflake/ml/feature_store/feature_view.py +30 -19
  15. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  16. snowflake/ml/fileset/snowfs.py +2 -30
  17. snowflake/ml/fileset/stage_fs.py +25 -7
  18. snowflake/ml/model/_client/model/model_impl.py +46 -39
  19. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  20. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  21. snowflake/ml/model/_client/ops/model_ops.py +174 -16
  22. snowflake/ml/model/_client/sql/_base.py +34 -0
  23. snowflake/ml/model/_client/sql/model.py +32 -39
  24. snowflake/ml/model/_client/sql/model_version.py +111 -42
  25. snowflake/ml/model/_client/sql/stage.py +6 -32
  26. snowflake/ml/model/_client/sql/tag.py +32 -56
  27. snowflake/ml/model/_model_composer/model_composer.py +8 -4
  28. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  29. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  30. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  31. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
  32. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
  33. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  34. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
  35. snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
  36. snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
  37. snowflake/ml/modeling/cluster/birch.py +8 -1
  38. snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
  39. snowflake/ml/modeling/cluster/dbscan.py +8 -1
  40. snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
  41. snowflake/ml/modeling/cluster/k_means.py +8 -1
  42. snowflake/ml/modeling/cluster/mean_shift.py +8 -1
  43. snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
  44. snowflake/ml/modeling/cluster/optics.py +8 -1
  45. snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
  46. snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
  47. snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
  48. snowflake/ml/modeling/compose/column_transformer.py +8 -1
  49. snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
  50. snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
  51. snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
  52. snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
  53. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
  54. snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
  55. snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
  56. snowflake/ml/modeling/covariance/oas.py +8 -1
  57. snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
  58. snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
  59. snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
  60. snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
  61. snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
  62. snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
  63. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
  64. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
  65. snowflake/ml/modeling/decomposition/pca.py +8 -1
  66. snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
  67. snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
  68. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
  69. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
  70. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
  71. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
  72. snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
  73. snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
  74. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
  75. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
  76. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
  77. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
  79. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
  80. snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
  81. snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
  82. snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
  83. snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
  84. snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
  85. snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
  86. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
  87. snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
  88. snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
  89. snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
  90. snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
  91. snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
  92. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
  93. snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
  94. snowflake/ml/modeling/framework/base.py +4 -3
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
  97. snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
  98. snowflake/ml/modeling/impute/knn_imputer.py +8 -1
  99. snowflake/ml/modeling/impute/missing_indicator.py +8 -1
  100. snowflake/ml/modeling/impute/simple_imputer.py +21 -2
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
  115. snowflake/ml/modeling/linear_model/lars.py +8 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +8 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +8 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +8 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
  144. snowflake/ml/modeling/manifold/isomap.py +8 -1
  145. snowflake/ml/modeling/manifold/mds.py +8 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
  147. snowflake/ml/modeling/manifold/tsne.py +8 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
  170. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  171. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +8 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +8 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +8 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +8 -1
  179. snowflake/ml/modeling/svm/svc.py +8 -1
  180. snowflake/ml/modeling/svm/svr.py +8 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
  189. snowflake/ml/registry/_manager/model_manager.py +95 -8
  190. snowflake/ml/registry/registry.py +10 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
  193. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
  194. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  195. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,9 @@
1
1
  import collections
2
2
  import logging
3
3
  import re
4
- from typing import Any, Dict, Optional
4
+ from typing import Any, Optional
5
5
 
6
6
  import fsspec
7
- import packaging.version as pkg_version
8
7
 
9
8
  from snowflake import snowpark
10
9
  from snowflake.connector import connection
@@ -12,7 +11,7 @@ from snowflake.ml._internal.exceptions import (
12
11
  error_codes,
13
12
  exceptions as snowml_exceptions,
14
13
  )
15
- from snowflake.ml._internal.utils import identifier, snowflake_env
14
+ from snowflake.ml._internal.utils import identifier
16
15
  from snowflake.ml.fileset import embedded_stage_fs, sfcfs
17
16
 
18
17
  PROTOCOL_NAME = "snow"
@@ -28,10 +27,6 @@ _SNOWURL_PATTERN = re.compile(
28
27
  r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
29
28
  )
30
29
 
31
- # FIXME(dhung): Temporary fix for bug in GS version 8.17
32
- _BUG_VERSION_MIN = pkg_version.Version("8.17") # Inclusive minimum version with bugged behavior
33
- _BUG_VERSION_MAX = pkg_version.Version("8.18") # Exclusive maximum version with bugged behavior
34
-
35
30
 
36
31
  class SnowFileSystem(sfcfs.SFFileSystem):
37
32
  """A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations.
@@ -54,21 +49,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
54
49
  ) -> None:
55
50
  super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs)
56
51
 
57
- # FIXME(dhung): Temporary fix for bug in GS version 8.17
58
- if SnowFileSystem._IS_BUGGED_VERSION is None:
59
- try:
60
- sf_version = snowflake_env.get_current_snowflake_version(self._session)
61
- SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX
62
- except Exception:
63
- SnowFileSystem._IS_BUGGED_VERSION = False
64
-
65
- def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
66
- # FIXME(dhung): Temporary fix for bug in GS version 8.17
67
- res: Dict[str, Any] = super().info(path, **kwargs)
68
- if res.get("type") == "directory" and not res["name"].endswith("/"):
69
- res["name"] += "/"
70
- return res
71
-
72
52
  def _get_stage_fs(
73
53
  self, sf_file_path: _SFFileEntityPath # type: ignore[override]
74
54
  ) -> embedded_stage_fs.SFEmbeddedStageFileSystem:
@@ -100,11 +80,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
100
80
  if stage_name.startswith(protocol):
101
81
  stage_name = stage_name[len(protocol) :]
102
82
  abs_path = stage_name + "/" + path
103
- # FIXME(dhung): Temporary fix for bug in GS version 8.17
104
- if self._IS_BUGGED_VERSION:
105
- match = _SNOWURL_PATTERN.fullmatch(abs_path)
106
- assert match is not None
107
- abs_path = abs_path.replace(match.group("relpath"), match.group("relpath").lstrip("/"))
108
83
  return abs_path
109
84
 
110
85
  @classmethod
@@ -143,9 +118,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
143
118
  version = snowurl_match.group("version")
144
119
  relative_path = snowurl_match.group("relpath") or ""
145
120
  logging.debug(f"Parsed snow URL: {snowurl_match.groups()}")
146
- # FIXME(dhung): Temporary fix for bug in GS version 8.17
147
- if cls._IS_BUGGED_VERSION:
148
- filepath = filepath.replace(f"{version}/", f"{version}//")
149
121
  return _SFFileEntityPath(
150
122
  domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath
151
123
  )
@@ -2,13 +2,13 @@ import inspect
2
2
  import logging
3
3
  import time
4
4
  from dataclasses import dataclass
5
- from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
6
6
 
7
7
  import fsspec
8
8
  from fsspec.implementations import http as httpfs
9
9
 
10
10
  from snowflake import snowpark
11
- from snowflake.connector import connection, errorcode
11
+ from snowflake.connector import connection, errorcode, errors as snowpark_errors
12
12
  from snowflake.ml._internal import telemetry
13
13
  from snowflake.ml._internal.exceptions import (
14
14
  error_codes,
@@ -18,6 +18,7 @@ from snowflake.ml._internal.exceptions import (
18
18
  )
19
19
  from snowflake.snowpark import exceptions as snowpark_exceptions
20
20
  from snowflake.snowpark._internal import utils as snowpark_utils
21
+ from snowflake.snowpark._internal.analyzer import snowflake_plan
21
22
 
22
23
  # The default length of how long a presigned url stays active in seconds.
23
24
  # Presigned url here is used to fetch file objects from Snowflake when SFStageFileSystem.open() is called.
@@ -167,7 +168,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
167
168
  try:
168
169
  loc = self.stage_name
169
170
  path = path.lstrip("/")
170
- objects = self._session.sql(f"LIST '{loc}/{path}'").collect()
171
+ async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
172
+ objects: List[snowpark.Row] = _resolve_async_job(async_job)
171
173
  except snowpark_exceptions.SnowparkClientException as e:
172
174
  if e.message.startswith(fileset_errors.ERRNO_DOMAIN_NOT_EXIST):
173
175
  raise snowml_exceptions.SnowflakeMLException(
@@ -289,9 +291,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
289
291
  original_exception=e,
290
292
  )
291
293
 
292
- def _parse_list_result(
293
- self, list_result: List[Tuple[str, int, str, str]], search_path: str
294
- ) -> List[Dict[str, Any]]:
294
+ def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
295
295
  """Convert the result from LIST query to the expected format of fsspec ls() method.
296
296
 
297
297
  Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
@@ -312,7 +312,8 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
312
312
  """
313
313
  files: Dict[str, Dict[str, Any]] = {}
314
314
  search_path = search_path.strip("/")
315
- for name, size, md5, last_modified in list_result:
315
+ for row in list_result:
316
+ name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
316
317
  obj_path = self._stage_path_to_relative_path(name)
317
318
  if obj_path == search_path:
318
319
  # If there is an exact match, then the matched object will always be a file object.
@@ -408,3 +409,20 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
408
409
  # Snowpark writes error code to message instead of populating e.error_code
409
410
  error_code_str = str(error_code)
410
411
  return ex.error_code == error_code_str or error_code_str in ex.message
412
+
413
+
414
+ @snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
415
+ def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
416
+ # Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
417
+ try:
418
+ query_result = cast(List[snowpark.Row], async_job.result("row"))
419
+ return query_result
420
+ except snowpark_errors.DatabaseError as e:
421
+ # HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
422
+ # assume it's due to FileNotFound
423
+ if type(e) is snowpark_errors.DatabaseError and "results are unavailable" in str(e):
424
+ raise snowml_exceptions.SnowflakeMLException(
425
+ error_code=error_codes.SNOWML_NOT_FOUND,
426
+ original_exception=fileset_errors.StageNotFoundError("Query failed."),
427
+ ) from e
428
+ raise
@@ -1,9 +1,9 @@
1
- from typing import Dict, List, Optional, Tuple, Union
1
+ from typing import Dict, List, Optional, Union
2
2
 
3
3
  import pandas as pd
4
4
 
5
5
  from snowflake.ml._internal import telemetry
6
- from snowflake.ml._internal.utils import identifier, sql_identifier
6
+ from snowflake.ml._internal.utils import sql_identifier
7
7
  from snowflake.ml.model._client.model import model_version_impl
8
8
  from snowflake.ml.model._client.ops import model_ops
9
9
 
@@ -45,7 +45,7 @@ class Model:
45
45
  @property
46
46
  def fully_qualified_name(self) -> str:
47
47
  """Return the fully qualified name of the model that can be used to refer to it in SQL."""
48
- return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
48
+ return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
49
49
 
50
50
  @property
51
51
  @telemetry.send_api_usage_telemetry(
@@ -76,6 +76,8 @@ class Model:
76
76
  subproject=_TELEMETRY_SUBPROJECT,
77
77
  )
78
78
  return self._model_ops.get_comment(
79
+ database_name=None,
80
+ schema_name=None,
79
81
  model_name=self._model_name,
80
82
  statement_params=statement_params,
81
83
  )
@@ -92,6 +94,8 @@ class Model:
92
94
  )
93
95
  return self._model_ops.set_comment(
94
96
  comment=comment,
97
+ database_name=None,
98
+ schema_name=None,
95
99
  model_name=self._model_name,
96
100
  statement_params=statement_params,
97
101
  )
@@ -109,7 +113,7 @@ class Model:
109
113
  class_name=self.__class__.__name__,
110
114
  )
111
115
  default_version_name = self._model_ops.get_default_version(
112
- model_name=self._model_name, statement_params=statement_params
116
+ database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
113
117
  )
114
118
  return self.version(default_version_name)
115
119
 
@@ -129,7 +133,11 @@ class Model:
129
133
  else:
130
134
  version_name = version._version_name
131
135
  self._model_ops.set_default_version(
132
- model_name=self._model_name, version_name=version_name, statement_params=statement_params
136
+ database_name=None,
137
+ schema_name=None,
138
+ model_name=self._model_name,
139
+ version_name=version_name,
140
+ statement_params=statement_params,
133
141
  )
134
142
 
135
143
  @telemetry.send_api_usage_telemetry(
@@ -155,6 +163,8 @@ class Model:
155
163
  )
156
164
  version_id = sql_identifier.SqlIdentifier(version_name)
157
165
  if self._model_ops.validate_existence(
166
+ database_name=None,
167
+ schema_name=None,
158
168
  model_name=self._model_name,
159
169
  version_name=version_id,
160
170
  statement_params=statement_params,
@@ -184,6 +194,8 @@ class Model:
184
194
  subproject=_TELEMETRY_SUBPROJECT,
185
195
  )
186
196
  version_names = self._model_ops.list_models_or_versions(
197
+ database_name=None,
198
+ schema_name=None,
187
199
  model_name=self._model_name,
188
200
  statement_params=statement_params,
189
201
  )
@@ -211,6 +223,8 @@ class Model:
211
223
  subproject=_TELEMETRY_SUBPROJECT,
212
224
  )
213
225
  rows = self._model_ops.show_models_or_versions(
226
+ database_name=None,
227
+ schema_name=None,
214
228
  model_name=self._model_name,
215
229
  statement_params=statement_params,
216
230
  )
@@ -231,6 +245,8 @@ class Model:
231
245
  subproject=_TELEMETRY_SUBPROJECT,
232
246
  )
233
247
  self._model_ops.delete_model_or_version(
248
+ database_name=None,
249
+ schema_name=None,
234
250
  model_name=self._model_name,
235
251
  version_name=sql_identifier.SqlIdentifier(version_name),
236
252
  statement_params=statement_params,
@@ -250,29 +266,9 @@ class Model:
250
266
  project=_TELEMETRY_PROJECT,
251
267
  subproject=_TELEMETRY_SUBPROJECT,
252
268
  )
253
- return self._model_ops.show_tags(model_name=self._model_name, statement_params=statement_params)
254
-
255
- def _parse_tag_name(
256
- self,
257
- tag_name: str,
258
- ) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
259
- _tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
260
- if _tag_db is None:
261
- tag_db_id = self._model_ops._model_client._database_name
262
- else:
263
- tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
264
-
265
- if _tag_schema is None:
266
- tag_schema_id = self._model_ops._model_client._schema_name
267
- else:
268
- tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
269
-
270
- if _tag_name is None:
271
- raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
272
-
273
- tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
274
-
275
- return tag_db_id, tag_schema_id, tag_name_id
269
+ return self._model_ops.show_tags(
270
+ database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
271
+ )
276
272
 
277
273
  @telemetry.send_api_usage_telemetry(
278
274
  project=_TELEMETRY_PROJECT,
@@ -292,8 +288,10 @@ class Model:
292
288
  project=_TELEMETRY_PROJECT,
293
289
  subproject=_TELEMETRY_SUBPROJECT,
294
290
  )
295
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
291
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
296
292
  return self._model_ops.get_tag_value(
293
+ database_name=None,
294
+ schema_name=None,
297
295
  model_name=self._model_name,
298
296
  tag_database_name=tag_db_id,
299
297
  tag_schema_name=tag_schema_id,
@@ -317,8 +315,10 @@ class Model:
317
315
  project=_TELEMETRY_PROJECT,
318
316
  subproject=_TELEMETRY_SUBPROJECT,
319
317
  )
320
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
318
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
321
319
  self._model_ops.set_tag(
320
+ database_name=None,
321
+ schema_name=None,
322
322
  model_name=self._model_name,
323
323
  tag_database_name=tag_db_id,
324
324
  tag_schema_name=tag_schema_id,
@@ -342,8 +342,10 @@ class Model:
342
342
  project=_TELEMETRY_PROJECT,
343
343
  subproject=_TELEMETRY_SUBPROJECT,
344
344
  )
345
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
345
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
346
346
  self._model_ops.unset_tag(
347
+ database_name=None,
348
+ schema_name=None,
347
349
  model_name=self._model_name,
348
350
  tag_database_name=tag_db_id,
349
351
  tag_schema_name=tag_schema_id,
@@ -365,15 +367,20 @@ class Model:
365
367
  project=_TELEMETRY_PROJECT,
366
368
  subproject=_TELEMETRY_SUBPROJECT,
367
369
  )
368
- db, schema, model, _ = identifier.parse_schema_level_object_identifier(model_name)
369
- new_model_db = sql_identifier.SqlIdentifier(db) if db else None
370
- new_model_schema = sql_identifier.SqlIdentifier(schema) if schema else None
371
- new_model_id = sql_identifier.SqlIdentifier(model)
370
+ new_db, new_schema, new_model = sql_identifier.parse_fully_qualified_name(model_name)
371
+
372
372
  self._model_ops.rename(
373
+ database_name=None,
374
+ schema_name=None,
373
375
  model_name=self._model_name,
374
- new_model_db=new_model_db,
375
- new_model_schema=new_model_schema,
376
- new_model_name=new_model_id,
376
+ new_model_db=new_db,
377
+ new_model_schema=new_schema,
378
+ new_model_name=new_model,
377
379
  statement_params=statement_params,
378
380
  )
379
- self._model_name = new_model_id
381
+ self._model_ops = model_ops.ModelOperator(
382
+ self._model_ops._session,
383
+ database_name=new_db or self._model_ops._model_client._database_name,
384
+ schema_name=new_schema or self._model_ops._model_client._schema_name,
385
+ )
386
+ self._model_name = new_model
@@ -72,7 +72,7 @@ class ModelVersion:
72
72
  @property
73
73
  def fully_qualified_model_name(self) -> str:
74
74
  """Return the fully qualified name of the model to which the model version belongs."""
75
- return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
75
+ return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
76
76
 
77
77
  @property
78
78
  @telemetry.send_api_usage_telemetry(
@@ -103,6 +103,8 @@ class ModelVersion:
103
103
  subproject=_TELEMETRY_SUBPROJECT,
104
104
  )
105
105
  return self._model_ops.get_comment(
106
+ database_name=None,
107
+ schema_name=None,
106
108
  model_name=self._model_name,
107
109
  version_name=self._version_name,
108
110
  statement_params=statement_params,
@@ -120,6 +122,8 @@ class ModelVersion:
120
122
  )
121
123
  return self._model_ops.set_comment(
122
124
  comment=comment,
125
+ database_name=None,
126
+ schema_name=None,
123
127
  model_name=self._model_name,
124
128
  version_name=self._version_name,
125
129
  statement_params=statement_params,
@@ -140,7 +144,11 @@ class ModelVersion:
140
144
  subproject=_TELEMETRY_SUBPROJECT,
141
145
  )
142
146
  return self._model_ops._metadata_ops.load(
143
- model_name=self._model_name, version_name=self._version_name, statement_params=statement_params
147
+ database_name=None,
148
+ schema_name=None,
149
+ model_name=self._model_name,
150
+ version_name=self._version_name,
151
+ statement_params=statement_params,
144
152
  )["metrics"]
145
153
 
146
154
  @telemetry.send_api_usage_telemetry(
@@ -183,6 +191,8 @@ class ModelVersion:
183
191
  metrics[metric_name] = value
184
192
  self._model_ops._metadata_ops.save(
185
193
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
194
+ database_name=None,
195
+ schema_name=None,
186
196
  model_name=self._model_name,
187
197
  version_name=self._version_name,
188
198
  statement_params=statement_params,
@@ -211,6 +221,8 @@ class ModelVersion:
211
221
  del metrics[metric_name]
212
222
  self._model_ops._metadata_ops.save(
213
223
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
224
+ database_name=None,
225
+ schema_name=None,
214
226
  model_name=self._model_name,
215
227
  version_name=self._version_name,
216
228
  statement_params=statement_params,
@@ -222,6 +234,8 @@ class ModelVersion:
222
234
  subproject=_TELEMETRY_SUBPROJECT,
223
235
  )
224
236
  return self._model_ops.get_functions(
237
+ database_name=None,
238
+ schema_name=None,
225
239
  model_name=self._model_name,
226
240
  version_name=self._version_name,
227
241
  statement_params=statement_params,
@@ -309,6 +323,8 @@ class ModelVersion:
309
323
  method_function_type=target_function_info["target_method_function_type"],
310
324
  signature=target_function_info["signature"],
311
325
  X=X,
326
+ database_name=None,
327
+ schema_name=None,
312
328
  model_name=self._model_name,
313
329
  version_name=self._version_name,
314
330
  strict_input_validation=strict_input_validation,
@@ -341,6 +357,8 @@ class ModelVersion:
341
357
  subproject=_TELEMETRY_SUBPROJECT,
342
358
  )
343
359
  self._model_ops.download_files(
360
+ database_name=None,
361
+ schema_name=None,
344
362
  model_name=self._model_name,
345
363
  version_name=self._version_name,
346
364
  target_path=target_local_path,
@@ -380,6 +398,8 @@ class ModelVersion:
380
398
  with tempfile.TemporaryDirectory() as tmp_workspace_for_validation:
381
399
  ws_path_for_validation = pathlib.Path(tmp_workspace_for_validation)
382
400
  self._model_ops.download_files(
401
+ database_name=None,
402
+ schema_name=None,
383
403
  model_name=self._model_name,
384
404
  version_name=self._version_name,
385
405
  target_path=ws_path_for_validation,
@@ -417,6 +437,8 @@ class ModelVersion:
417
437
  # The folder needs to exist.
418
438
  workspace = pathlib.Path(tempfile.mkdtemp())
419
439
  self._model_ops.download_files(
440
+ database_name=None,
441
+ schema_name=None,
420
442
  model_name=self._model_name,
421
443
  version_name=self._version_name,
422
444
  target_path=workspace,
@@ -61,12 +61,18 @@ class MetadataOperator:
61
61
  def _get_current_metadata_dict(
62
62
  self,
63
63
  *,
64
+ database_name: Optional[sql_identifier.SqlIdentifier],
65
+ schema_name: Optional[sql_identifier.SqlIdentifier],
64
66
  model_name: sql_identifier.SqlIdentifier,
65
67
  version_name: sql_identifier.SqlIdentifier,
66
68
  statement_params: Optional[Dict[str, Any]] = None,
67
69
  ) -> Dict[str, Any]:
68
70
  version_info_list = self._model_client.show_versions(
69
- model_name=model_name, version_name=version_name, statement_params=statement_params
71
+ database_name=database_name,
72
+ schema_name=schema_name,
73
+ model_name=model_name,
74
+ version_name=version_name,
75
+ statement_params=statement_params,
70
76
  )
71
77
  metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
72
78
  if not metadata_str:
@@ -79,12 +85,18 @@ class MetadataOperator:
79
85
  def load(
80
86
  self,
81
87
  *,
88
+ database_name: Optional[sql_identifier.SqlIdentifier],
89
+ schema_name: Optional[sql_identifier.SqlIdentifier],
82
90
  model_name: sql_identifier.SqlIdentifier,
83
91
  version_name: sql_identifier.SqlIdentifier,
84
92
  statement_params: Optional[Dict[str, Any]] = None,
85
93
  ) -> ModelVersionMetadataSchema:
86
94
  metadata_dict = self._get_current_metadata_dict(
87
- model_name=model_name, version_name=version_name, statement_params=statement_params
95
+ database_name=database_name,
96
+ schema_name=schema_name,
97
+ model_name=model_name,
98
+ version_name=version_name,
99
+ statement_params=statement_params,
88
100
  )
89
101
  return MetadataOperator._parse(metadata_dict)
90
102
 
@@ -92,14 +104,25 @@ class MetadataOperator:
92
104
  self,
93
105
  metadata: ModelVersionMetadataSchema,
94
106
  *,
107
+ database_name: Optional[sql_identifier.SqlIdentifier],
108
+ schema_name: Optional[sql_identifier.SqlIdentifier],
95
109
  model_name: sql_identifier.SqlIdentifier,
96
110
  version_name: sql_identifier.SqlIdentifier,
97
111
  statement_params: Optional[Dict[str, Any]] = None,
98
112
  ) -> None:
99
113
  metadata_dict = self._get_current_metadata_dict(
100
- model_name=model_name, version_name=version_name, statement_params=statement_params
114
+ database_name=database_name,
115
+ schema_name=schema_name,
116
+ model_name=model_name,
117
+ version_name=version_name,
118
+ statement_params=statement_params,
101
119
  )
102
120
  metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION})
103
121
  self._model_version_client.set_metadata(
104
- metadata_dict, model_name=model_name, version_name=version_name, statement_params=statement_params
122
+ metadata_dict,
123
+ database_name=database_name,
124
+ schema_name=schema_name,
125
+ model_name=model_name,
126
+ version_name=version_name,
127
+ statement_params=statement_params,
105
128
  )