snowflake-ml-python 1.4.1 (py3-none-any wheel) → 1.5.1 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,11 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import row, session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import row
9
6
 
10
7
 
11
- class ModelSQLClient:
8
+ class ModelSQLClient(_base._BaseSQLClient):
12
9
  MODEL_NAME_COL_NAME = "name"
13
10
  MODEL_COMMENT_COL_NAME = "comment"
14
11
  MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
@@ -18,35 +15,18 @@ class ModelSQLClient:
18
15
  MODEL_VERSION_METADATA_COL_NAME = "metadata"
19
16
  MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
20
17
 
21
- def __init__(
22
- self,
23
- session: session.Session,
24
- *,
25
- database_name: sql_identifier.SqlIdentifier,
26
- schema_name: sql_identifier.SqlIdentifier,
27
- ) -> None:
28
- self._session = session
29
- self._database_name = database_name
30
- self._schema_name = schema_name
31
-
32
- def __eq__(self, __value: object) -> bool:
33
- if not isinstance(__value, ModelSQLClient):
34
- return False
35
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
36
-
37
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
38
- return identifier.get_schema_level_object_identifier(
39
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
40
- )
41
-
42
18
  def show_models(
43
19
  self,
44
20
  *,
21
+ database_name: Optional[sql_identifier.SqlIdentifier],
22
+ schema_name: Optional[sql_identifier.SqlIdentifier],
45
23
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
46
24
  validate_result: bool = True,
47
25
  statement_params: Optional[Dict[str, Any]] = None,
48
26
  ) -> List[row.Row]:
49
- fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
27
+ actual_database_name = database_name or self._database_name
28
+ actual_schema_name = schema_name or self._schema_name
29
+ fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
50
30
  like_sql = ""
51
31
  if model_name:
52
32
  like_sql = f" LIKE '{model_name.resolved()}'"
@@ -69,6 +49,8 @@ class ModelSQLClient:
69
49
  def show_versions(
70
50
  self,
71
51
  *,
52
+ database_name: Optional[sql_identifier.SqlIdentifier],
53
+ schema_name: Optional[sql_identifier.SqlIdentifier],
72
54
  model_name: sql_identifier.SqlIdentifier,
73
55
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
74
56
  validate_result: bool = True,
@@ -82,7 +64,10 @@ class ModelSQLClient:
82
64
  res = (
83
65
  query_result_checker.SqlResultValidator(
84
66
  self._session,
85
- f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}",
67
+ (
68
+ f"SHOW VERSIONS{like_sql} IN "
69
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
70
+ ),
86
71
  statement_params=statement_params,
87
72
  )
88
73
  .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
@@ -99,43 +84,53 @@ class ModelSQLClient:
99
84
  def set_comment(
100
85
  self,
101
86
  *,
102
- comment: str,
87
+ database_name: Optional[sql_identifier.SqlIdentifier],
88
+ schema_name: Optional[sql_identifier.SqlIdentifier],
103
89
  model_name: sql_identifier.SqlIdentifier,
90
+ comment: str,
104
91
  statement_params: Optional[Dict[str, Any]] = None,
105
92
  ) -> None:
106
93
  query_result_checker.SqlResultValidator(
107
94
  self._session,
108
- f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$",
95
+ (
96
+ f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
97
+ f" IS $${comment}$$"
98
+ ),
109
99
  statement_params=statement_params,
110
100
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
101
 
112
102
  def drop_model(
113
103
  self,
114
104
  *,
105
+ database_name: Optional[sql_identifier.SqlIdentifier],
106
+ schema_name: Optional[sql_identifier.SqlIdentifier],
115
107
  model_name: sql_identifier.SqlIdentifier,
116
108
  statement_params: Optional[Dict[str, Any]] = None,
117
109
  ) -> None:
118
110
  query_result_checker.SqlResultValidator(
119
111
  self._session,
120
- f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
112
+ f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
121
113
  statement_params=statement_params,
122
114
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
123
115
 
124
- def config_model_details(
116
+ def rename(
125
117
  self,
126
118
  *,
127
- enable: bool,
119
+ database_name: Optional[sql_identifier.SqlIdentifier],
120
+ schema_name: Optional[sql_identifier.SqlIdentifier],
121
+ model_name: sql_identifier.SqlIdentifier,
122
+ new_model_db: Optional[sql_identifier.SqlIdentifier],
123
+ new_model_schema: Optional[sql_identifier.SqlIdentifier],
124
+ new_model_name: sql_identifier.SqlIdentifier,
128
125
  statement_params: Optional[Dict[str, Any]] = None,
129
126
  ) -> None:
130
- if enable:
131
- query_result_checker.SqlResultValidator(
132
- self._session,
133
- "ALTER SESSION SET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL=true",
134
- statement_params=statement_params,
135
- ).has_dimensions(expected_rows=1, expected_cols=1).validate()
136
- else:
137
- query_result_checker.SqlResultValidator(
138
- self._session,
139
- "ALTER SESSION UNSET SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL",
140
- statement_params=statement_params,
141
- ).has_dimensions(expected_rows=1, expected_cols=1).validate()
127
+ # Use registry's database and schema if a non fully qualified new model name is provided.
128
+ new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
129
+ query_result_checker.SqlResultValidator(
130
+ self._session,
131
+ (
132
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
133
+ f" RENAME TO {new_fully_qualified_name}"
134
+ ),
135
+ statement_params=statement_params,
136
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -9,7 +9,8 @@ from snowflake.ml._internal.utils import (
9
9
  query_result_checker,
10
10
  sql_identifier,
11
11
  )
12
- from snowflake.snowpark import dataframe, functions as F, row, session, types as spt
12
+ from snowflake.ml.model._client.sql import _base
13
+ from snowflake.snowpark import dataframe, functions as F, row, types as spt
13
14
  from snowflake.snowpark._internal import utils as snowpark_utils
14
15
 
15
16
 
@@ -20,34 +21,15 @@ def _normalize_url_for_sql(url: str) -> str:
20
21
  return f"'{url}'"
21
22
 
22
23
 
23
- class ModelVersionSQLClient:
24
+ class ModelVersionSQLClient(_base._BaseSQLClient):
24
25
  FUNCTION_NAME_COL_NAME = "name"
25
26
  FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
26
27
 
27
- def __init__(
28
- self,
29
- session: session.Session,
30
- *,
31
- database_name: sql_identifier.SqlIdentifier,
32
- schema_name: sql_identifier.SqlIdentifier,
33
- ) -> None:
34
- self._session = session
35
- self._database_name = database_name
36
- self._schema_name = schema_name
37
-
38
- def __eq__(self, __value: object) -> bool:
39
- if not isinstance(__value, ModelVersionSQLClient):
40
- return False
41
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
42
-
43
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
44
- return identifier.get_schema_level_object_identifier(
45
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
46
- )
47
-
48
28
  def create_from_stage(
49
29
  self,
50
30
  *,
31
+ database_name: Optional[sql_identifier.SqlIdentifier],
32
+ schema_name: Optional[sql_identifier.SqlIdentifier],
51
33
  model_name: sql_identifier.SqlIdentifier,
52
34
  version_name: sql_identifier.SqlIdentifier,
53
35
  stage_path: str,
@@ -56,8 +38,8 @@ class ModelVersionSQLClient:
56
38
  query_result_checker.SqlResultValidator(
57
39
  self._session,
58
40
  (
59
- f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
60
- f" FROM {stage_path}"
41
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
42
+ f" WITH VERSION {version_name.identifier()} FROM {stage_path}"
61
43
  ),
62
44
  statement_params=statement_params,
63
45
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -66,6 +48,8 @@ class ModelVersionSQLClient:
66
48
  def add_version_from_stage(
67
49
  self,
68
50
  *,
51
+ database_name: Optional[sql_identifier.SqlIdentifier],
52
+ schema_name: Optional[sql_identifier.SqlIdentifier],
69
53
  model_name: sql_identifier.SqlIdentifier,
70
54
  version_name: sql_identifier.SqlIdentifier,
71
55
  stage_path: str,
@@ -74,8 +58,8 @@ class ModelVersionSQLClient:
74
58
  query_result_checker.SqlResultValidator(
75
59
  self._session,
76
60
  (
77
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
78
- f" FROM {stage_path}"
61
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
62
+ f" ADD VERSION {version_name.identifier()} FROM {stage_path}"
79
63
  ),
80
64
  statement_params=statement_params,
81
65
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -83,6 +67,8 @@ class ModelVersionSQLClient:
83
67
  def set_default_version(
84
68
  self,
85
69
  *,
70
+ database_name: Optional[sql_identifier.SqlIdentifier],
71
+ schema_name: Optional[sql_identifier.SqlIdentifier],
86
72
  model_name: sql_identifier.SqlIdentifier,
87
73
  version_name: sql_identifier.SqlIdentifier,
88
74
  statement_params: Optional[Dict[str, Any]] = None,
@@ -90,15 +76,54 @@ class ModelVersionSQLClient:
90
76
  query_result_checker.SqlResultValidator(
91
77
  self._session,
92
78
  (
93
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
79
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
94
80
  f"SET DEFAULT_VERSION = {version_name.identifier()}"
95
81
  ),
96
82
  statement_params=statement_params,
97
83
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
98
84
 
85
+ def list_file(
86
+ self,
87
+ *,
88
+ database_name: Optional[sql_identifier.SqlIdentifier],
89
+ schema_name: Optional[sql_identifier.SqlIdentifier],
90
+ model_name: sql_identifier.SqlIdentifier,
91
+ version_name: sql_identifier.SqlIdentifier,
92
+ file_path: pathlib.PurePosixPath,
93
+ is_dir: bool = False,
94
+ statement_params: Optional[Dict[str, Any]] = None,
95
+ ) -> List[row.Row]:
96
+ # Workaround for snowURL bug.
97
+ trailing_slash = "/" if is_dir else ""
98
+
99
+ stage_location = (
100
+ pathlib.PurePosixPath(
101
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
102
+ "versions",
103
+ version_name.resolved(),
104
+ file_path,
105
+ ).as_posix()
106
+ + trailing_slash
107
+ )
108
+ stage_location_url = ParseResult(
109
+ scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
110
+ ).geturl()
111
+
112
+ return (
113
+ query_result_checker.SqlResultValidator(
114
+ self._session,
115
+ f"List {_normalize_url_for_sql(stage_location_url)}",
116
+ statement_params=statement_params,
117
+ )
118
+ .has_column("name", allow_empty=True)
119
+ .validate()
120
+ )
121
+
99
122
  def get_file(
100
123
  self,
101
124
  *,
125
+ database_name: Optional[sql_identifier.SqlIdentifier],
126
+ schema_name: Optional[sql_identifier.SqlIdentifier],
102
127
  model_name: sql_identifier.SqlIdentifier,
103
128
  version_name: sql_identifier.SqlIdentifier,
104
129
  file_path: pathlib.PurePosixPath,
@@ -106,7 +131,10 @@ class ModelVersionSQLClient:
106
131
  statement_params: Optional[Dict[str, Any]] = None,
107
132
  ) -> pathlib.Path:
108
133
  stage_location = pathlib.PurePosixPath(
109
- self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
134
+ self.fully_qualified_object_name(database_name, schema_name, model_name),
135
+ "versions",
136
+ version_name.resolved(),
137
+ file_path,
110
138
  ).as_posix()
111
139
  stage_location_url = ParseResult(
112
140
  scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
@@ -130,6 +158,8 @@ class ModelVersionSQLClient:
130
158
  def show_functions(
131
159
  self,
132
160
  *,
161
+ database_name: Optional[sql_identifier.SqlIdentifier],
162
+ schema_name: Optional[sql_identifier.SqlIdentifier],
133
163
  model_name: sql_identifier.SqlIdentifier,
134
164
  version_name: sql_identifier.SqlIdentifier,
135
165
  statement_params: Optional[Dict[str, Any]] = None,
@@ -137,7 +167,7 @@ class ModelVersionSQLClient:
137
167
  res = query_result_checker.SqlResultValidator(
138
168
  self._session,
139
169
  (
140
- f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_model_name(model_name)}"
170
+ f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
141
171
  f" VERSION {version_name.identifier()}"
142
172
  ),
143
173
  statement_params=statement_params,
@@ -148,23 +178,27 @@ class ModelVersionSQLClient:
148
178
  def set_comment(
149
179
  self,
150
180
  *,
151
- comment: str,
181
+ database_name: Optional[sql_identifier.SqlIdentifier],
182
+ schema_name: Optional[sql_identifier.SqlIdentifier],
152
183
  model_name: sql_identifier.SqlIdentifier,
153
184
  version_name: sql_identifier.SqlIdentifier,
185
+ comment: str,
154
186
  statement_params: Optional[Dict[str, Any]] = None,
155
187
  ) -> None:
156
188
  query_result_checker.SqlResultValidator(
157
189
  self._session,
158
190
  (
159
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
191
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)} "
160
192
  f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
161
193
  ),
162
194
  statement_params=statement_params,
163
195
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
164
196
 
165
- def invoke_method(
197
+ def invoke_function_method(
166
198
  self,
167
199
  *,
200
+ database_name: Optional[sql_identifier.SqlIdentifier],
201
+ schema_name: Optional[sql_identifier.SqlIdentifier],
168
202
  model_name: sql_identifier.SqlIdentifier,
169
203
  version_name: sql_identifier.SqlIdentifier,
170
204
  method_name: sql_identifier.SqlIdentifier,
@@ -178,10 +212,12 @@ class ModelVersionSQLClient:
178
212
  INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
179
213
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
180
214
  else:
215
+ actual_database_name = database_name or self._database_name
216
+ actual_schema_name = schema_name or self._schema_name
181
217
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
182
218
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
183
- self._database_name.identifier(),
184
- self._schema_name.identifier(),
219
+ actual_database_name.identifier(),
220
+ actual_schema_name.identifier(),
185
221
  tmp_table_name,
186
222
  )
187
223
  input_df.write.save_as_table( # type: ignore[call-overload]
@@ -196,7 +232,8 @@ class ModelVersionSQLClient:
196
232
  module_version_alias = "MODEL_VERSION_ALIAS"
197
233
  with_statements.append(
198
234
  f"{module_version_alias} AS "
199
- f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
235
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
236
+ f" VERSION {version_name.identifier()}"
200
237
  )
201
238
 
202
239
  args_sql_list = []
@@ -232,10 +269,93 @@ class ModelVersionSQLClient:
232
269
 
233
270
  return output_df
234
271
 
272
+ def invoke_table_function_method(
273
+ self,
274
+ *,
275
+ database_name: Optional[sql_identifier.SqlIdentifier],
276
+ schema_name: Optional[sql_identifier.SqlIdentifier],
277
+ model_name: sql_identifier.SqlIdentifier,
278
+ version_name: sql_identifier.SqlIdentifier,
279
+ method_name: sql_identifier.SqlIdentifier,
280
+ input_df: dataframe.DataFrame,
281
+ input_args: List[sql_identifier.SqlIdentifier],
282
+ returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
283
+ partition_column: Optional[sql_identifier.SqlIdentifier],
284
+ statement_params: Optional[Dict[str, Any]] = None,
285
+ ) -> dataframe.DataFrame:
286
+ with_statements = []
287
+ if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
288
+ INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
289
+ with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
290
+ else:
291
+ actual_database_name = database_name or self._database_name
292
+ actual_schema_name = schema_name or self._schema_name
293
+ tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
294
+ INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
295
+ actual_database_name.identifier(),
296
+ actual_schema_name.identifier(),
297
+ tmp_table_name,
298
+ )
299
+ input_df.write.save_as_table( # type: ignore[call-overload]
300
+ table_name=INTERMEDIATE_TABLE_NAME,
301
+ mode="errorifexists",
302
+ table_type="temporary",
303
+ statement_params=statement_params,
304
+ )
305
+
306
+ module_version_alias = "MODEL_VERSION_ALIAS"
307
+ with_statements.append(
308
+ f"{module_version_alias} AS "
309
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
310
+ f" VERSION {version_name.identifier()}"
311
+ )
312
+
313
+ partition_by = partition_column.identifier() if partition_column is not None else "1"
314
+
315
+ args_sql_list = []
316
+ for input_arg_value in input_args:
317
+ args_sql_list.append(input_arg_value)
318
+
319
+ args_sql = ", ".join(args_sql_list)
320
+
321
+ sql = textwrap.dedent(
322
+ f"""WITH {','.join(with_statements)}
323
+ SELECT *,
324
+ FROM {INTERMEDIATE_TABLE_NAME},
325
+ TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
326
+ OVER (PARTITION BY {partition_by}))"""
327
+ )
328
+
329
+ output_df = self._session.sql(sql)
330
+
331
+ # Prepare the output
332
+ output_cols = []
333
+ output_names = []
334
+
335
+ for output_name, output_type, output_col_name in returns:
336
+ output_cols.append(F.col(output_name).astype(output_type))
337
+ output_names.append(output_col_name)
338
+
339
+ if partition_column is not None:
340
+ output_cols.append(F.col(partition_column.identifier()))
341
+ output_names.append(partition_column)
342
+
343
+ output_df = output_df.with_columns(
344
+ col_names=output_names,
345
+ values=output_cols,
346
+ )
347
+
348
+ if statement_params:
349
+ output_df._statement_params = statement_params # type: ignore[assignment]
350
+
351
+ return output_df
352
+
235
353
  def set_metadata(
236
354
  self,
237
355
  metadata_dict: Dict[str, Any],
238
356
  *,
357
+ database_name: Optional[sql_identifier.SqlIdentifier],
358
+ schema_name: Optional[sql_identifier.SqlIdentifier],
239
359
  model_name: sql_identifier.SqlIdentifier,
240
360
  version_name: sql_identifier.SqlIdentifier,
241
361
  statement_params: Optional[Dict[str, Any]] = None,
@@ -244,8 +364,8 @@ class ModelVersionSQLClient:
244
364
  query_result_checker.SqlResultValidator(
245
365
  self._session,
246
366
  (
247
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
248
- f" SET METADATA=$${json_metadata}$$"
367
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
368
+ f" MODIFY VERSION {version_name.identifier()} SET METADATA=$${json_metadata}$$"
249
369
  ),
250
370
  statement_params=statement_params,
251
371
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -253,12 +373,17 @@ class ModelVersionSQLClient:
253
373
  def drop_version(
254
374
  self,
255
375
  *,
376
+ database_name: Optional[sql_identifier.SqlIdentifier],
377
+ schema_name: Optional[sql_identifier.SqlIdentifier],
256
378
  model_name: sql_identifier.SqlIdentifier,
257
379
  version_name: sql_identifier.SqlIdentifier,
258
380
  statement_params: Optional[Dict[str, Any]] = None,
259
381
  ) -> None:
260
382
  query_result_checker.SqlResultValidator(
261
383
  self._session,
262
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} DROP VERSION {version_name.identifier()}",
384
+ (
385
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
386
+ f" DROP VERSION {version_name.identifier()}"
387
+ ),
263
388
  statement_params=statement_params,
264
389
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,46 +1,20 @@
1
1
  from typing import Any, Dict, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
9
5
 
10
6
 
11
- class StageSQLClient:
12
- def __init__(
13
- self,
14
- session: session.Session,
15
- *,
16
- database_name: sql_identifier.SqlIdentifier,
17
- schema_name: sql_identifier.SqlIdentifier,
18
- ) -> None:
19
- self._session = session
20
- self._database_name = database_name
21
- self._schema_name = schema_name
22
-
23
- def __eq__(self, __value: object) -> bool:
24
- if not isinstance(__value, StageSQLClient):
25
- return False
26
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
27
-
28
- def fully_qualified_stage_name(
29
- self,
30
- stage_name: sql_identifier.SqlIdentifier,
31
- ) -> str:
32
- return identifier.get_schema_level_object_identifier(
33
- self._database_name.identifier(), self._schema_name.identifier(), stage_name.identifier()
34
- )
35
-
7
+ class StageSQLClient(_base._BaseSQLClient):
36
8
  def create_tmp_stage(
37
9
  self,
38
10
  *,
11
+ database_name: Optional[sql_identifier.SqlIdentifier],
12
+ schema_name: Optional[sql_identifier.SqlIdentifier],
39
13
  stage_name: sql_identifier.SqlIdentifier,
40
14
  statement_params: Optional[Dict[str, Any]] = None,
41
15
  ) -> None:
42
16
  query_result_checker.SqlResultValidator(
43
17
  self._session,
44
- f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}",
18
+ f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
45
19
  statement_params=statement_params,
46
20
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()