snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,308 @@
1
+ import pathlib
2
+ import tempfile
3
+ from typing import Any, Dict, List, Optional, Union, cast
4
+
5
+ import yaml
6
+
7
+ from snowflake.ml._internal.utils import sql_identifier
8
+ from snowflake.ml.model import model_signature, type_hints
9
+ from snowflake.ml.model._client.ops import metadata_ops
10
+ from snowflake.ml.model._client.sql import (
11
+ model as model_sql,
12
+ model_version as model_version_sql,
13
+ stage as stage_sql,
14
+ )
15
+ from snowflake.ml.model._model_composer import model_composer
16
+ from snowflake.ml.model._model_composer.model_manifest import (
17
+ model_manifest,
18
+ model_manifest_schema,
19
+ )
20
+ from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
21
+ from snowflake.ml.model._signatures import snowpark_handler
22
+ from snowflake.snowpark import dataframe, session
23
+ from snowflake.snowpark._internal import utils as snowpark_utils
24
+
25
+
26
class ModelOperator:
    """Model-level operations composed from the stage, model and model-version SQL clients.

    Every operation is scoped to the database and schema supplied at construction time.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Build the underlying SQL clients, all bound to the same session, database and schema.

        Args:
            session: Snowpark session handed to every underlying client.
            database_name: Database the clients operate in.
            schema_name: Schema the clients operate in.
        """
        # Ideally, only the SQL clients would keep the session object. However, components other than the
        # clients, such as ModelComposer and SnowparkDataFrameHandler, still require a session. We currently
        # cannot refactor them all, so avoid using `_session` here unless there is no other choice.
        self._session = session
        self._stage_client = stage_sql.StageSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_client = model_sql.ModelSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_version_client = model_version_sql.ModelVersionSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._metadata_ops = metadata_ops.MetadataOperator(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )

    def __eq__(self, __value: object) -> bool:
        """Compare two operators by their stage, model and model-version clients.

        `_metadata_ops` and `_session` do not take part in the comparison.
        """
        if not isinstance(__value, ModelOperator):
            return False
        return (
            self._stage_client == __value._stage_client
            and self._model_client == __value._model_client
            and self._model_version_client == __value._model_version_client
        )

    def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
        """Create a randomly named stage via `create_tmp_stage` and return the `model` path inside it.

        Args:
            statement_params: Forwarded to the stage client.

        Returns:
            A stage location of the form ``@<fully qualified stage name>/model``.
        """
        stage_name = sql_identifier.SqlIdentifier(
            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
        )
        self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params)
        return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model"

    def create_from_stage(
        self,
        composed_model: model_composer.ModelComposer,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Register the staged, composed model as a new model or as a new version of an existing one.

        Args:
            composed_model: Composer whose `stage_path` points at the staged model files.
            model_name: Name of the model to create or extend.
            version_name: Name of the version to create.
            statement_params: Forwarded to every SQL call.

        Raises:
            ValueError: If the model already has a version with this name.
        """
        stage_path = str(composed_model.stage_path)
        if self.validate_existence(
            model_name=model_name,
            statement_params=statement_params,
        ):
            # The model exists: only a version that is not there yet may be added.
            if self.validate_existence(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            ):
                raise ValueError(
                    f"Model {self._model_version_client.fully_qualified_model_name(model_name)} "
                    f"version {version_name} already existed."
                )
            else:
                self._model_version_client.add_version_from_stage(
                    stage_path=stage_path,
                    model_name=model_name,
                    version_name=version_name,
                    statement_params=statement_params,
                )
        else:
            # The model does not exist yet: create it together with its first version.
            self._model_version_client.create_from_stage(
                stage_path=stage_path,
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )

    def list_models_or_versions(
        self,
        *,
        model_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[sql_identifier.SqlIdentifier]:
        """List the versions of `model_name`, or all models in the schema when no model is given.

        Args:
            model_name: Model whose versions are listed; when omitted, models are listed instead.
            statement_params: Forwarded to the SQL call.

        Returns:
            The `name` of every returned row, wrapped as a case-sensitive identifier.
        """
        if model_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                statement_params=statement_params,
            )
        return [sql_identifier.SqlIdentifier(row.name, case_sensitive=True) for row in res]

    def validate_existence(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Tell whether the model, or the given version of it, exists.

        Args:
            model_name: Model to look up.
            version_name: When given, look up this version of the model instead of the model itself.
            statement_params: Forwarded to the SQL call.

        Returns:
            True only when the lookup yields exactly one row.
        """
        if version_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                model_name=model_name,
                statement_params=statement_params,
            )
        return len(res) == 1

    def get_comment(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Return the comment of the model, or of the given version of it.

        Args:
            model_name: Model to read the comment from.
            version_name: When given, read the comment of this version instead.
            statement_params: Forwarded to the SQL call.

        Returns:
            The `comment` field of the single matching row.
        """
        if version_name:
            res = self._model_client.show_versions(
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            res = self._model_client.show_models(
                model_name=model_name,
                statement_params=statement_params,
            )
        # The target is expected to exist and to be matched unambiguously.
        assert len(res) == 1
        return cast(str, res[0].comment)

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set the comment of the model, or of the given version of it.

        Args:
            comment: New comment text.
            model_name: Model to update.
            version_name: When given, update this version instead of the model itself.
            statement_params: Forwarded to the SQL call.
        """
        if version_name:
            self._model_version_client.set_comment(
                comment=comment,
                model_name=model_name,
                version_name=version_name,
                statement_params=statement_params,
            )
        else:
            self._model_client.set_comment(
                comment=comment,
                model_name=model_name,
                statement_params=statement_params,
            )

    def get_model_version_manifest(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_manifest_schema.ModelManifestDict:
        """Fetch the manifest file of a model version into a temporary directory and load it.

        Args:
            model_name: Model that owns the version.
            version_name: Version whose manifest is read.
            statement_params: Forwarded to the file download.

        Returns:
            The result of `ModelManifest.load()` on the downloaded manifest.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            self._model_version_client.get_file(
                model_name=model_name,
                version_name=version_name,
                file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
                target_path=pathlib.Path(tmpdir),
                statement_params=statement_params,
            )
            # Load before leaving the `with` block: the directory is removed on exit.
            mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
            return mm.load()

    def get_model_version_native_packing_meta(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_meta_schema.ModelMetadataDict:
        """Fetch the packaged model metadata file of a model version, parse it as YAML and validate it.

        Args:
            model_name: Model that owns the version.
            version_name: Version whose metadata is read.
            statement_params: Forwarded to the file download.

        Returns:
            The result of `ModelMetadata._validate_model_metadata` on the parsed YAML.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            model_meta_file_path = self._model_version_client.get_file(
                model_name=model_name,
                version_name=version_name,
                file_path=pathlib.PurePosixPath(
                    model_composer.ModelComposer.MODEL_DIR_REL_PATH, model_meta.MODEL_METADATA_FILE
                ),
                target_path=pathlib.Path(tmpdir),
                statement_params=statement_params,
            )
            with open(model_meta_file_path, encoding="utf-8") as f:
                raw_model_meta = yaml.safe_load(f)
            return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta)

    def invoke_method(
        self,
        *,
        method_name: sql_identifier.SqlIdentifier,
        signature: model_signature.ModelSignature,
        X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, str]] = None,
    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
        """Run a method of a model version on local data or on a Snowpark DataFrame.

        Local data is validated against `signature.inputs`, converted to a Snowpark DataFrame, and the
        result is sorted, stripped of the input columns and converted back to local data. A Snowpark
        DataFrame is validated and passed through as is; the result is returned as a Snowpark DataFrame
        that still carries the input columns.

        Args:
            method_name: Method of the model version to call.
            signature: Signature describing the method's input and output features.
            X: Input data, either local data or a Snowpark DataFrame.
            model_name: Model that owns the version.
            version_name: Version to invoke.
            statement_params: Forwarded to the SQL call.

        Returns:
            Local data when `X` was local data, otherwise a Snowpark DataFrame.
        """
        identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED

        # Validate and prepare input
        if not isinstance(X, dataframe.DataFrame):
            keep_order = True
            output_with_input_features = False
            df = model_signature._convert_and_validate_local_data(X, signature.inputs)
            s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(self._session, df, keep_order=keep_order)
        else:
            keep_order = False
            output_with_input_features = True
            # For Snowpark input the identifier rule comes from validating the data itself.
            identifier_rule = model_signature._validate_snowpark_data(X, signature.inputs)
            s_df = X

        original_cols = s_df.columns

        # Compose input and output names
        input_args = [
            identifier_rule.get_sql_identifier_from_feature(input_feature.name) for input_feature in signature.inputs
        ]

        returns = []
        for output_feature in signature.outputs:
            output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
            returns.append((output_feature.name, output_feature.as_snowpark_type(), output_name))
            # `original_cols` is what gets dropped below when output_with_input_features is False, so take
            # output names out of it: an output sharing its name with an input column must survive the drop.
            if output_name in original_cols:
                original_cols.remove(output_name)

        df_res = self._model_version_client.invoke_method(
            method_name=method_name,
            input_df=s_df,
            input_args=input_args,
            returns=returns,
            model_name=model_name,
            version_name=version_name,
            statement_params=statement_params,
        )

        if keep_order:
            # NOTE(review): "_ID" is presumably the ordering column added by convert_from_df(keep_order=True);
            # confirm against SnowparkDataFrameHandler.
            df_res = df_res.sort(
                "_ID",
                ascending=True,
            )

        if not output_with_input_features:
            df_res = df_res.drop(*original_cols)

        # Get final result
        if not isinstance(X, dataframe.DataFrame):
            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
        else:
            return df_res

    def delete_model_or_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Drop a model.

        Args:
            model_name: Model to drop.
            version_name: Must be left unset; deleting a single version is not supported yet.
            statement_params: Forwarded to the SQL call.

        Raises:
            NotImplementedError: If `version_name` is given.
        """
        # TODO: Delete version is not supported yet.
        # Refuse explicitly instead of ignoring `version_name`: silently falling through to DROP MODEL
        # would destroy every version of the model when the caller asked to remove only one.
        if version_name is not None:
            raise NotImplementedError(
                f"Deleting version {version_name} of model {model_name} is not supported yet; "
                "only a whole model can be dropped."
            )
        self._model_client.drop_model(
            model_name=model_name,
            statement_params=statement_params,
        )
@@ -0,0 +1,75 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from snowflake.ml._internal.utils import identifier, sql_identifier
4
+ from snowflake.snowpark import row, session
5
+
6
+
7
class ModelSQLClient:
    """Runs model-level SQL statements (SHOW / COMMENT ON / DROP) within one database and schema."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Remember the session and the database/schema that every statement is issued against."""
        self._session = session
        self._database_name, self._schema_name = database_name, schema_name

    def __eq__(self, __value: object) -> bool:
        """Two clients are equal when they target the same database and schema; the session is ignored."""
        return (
            isinstance(__value, ModelSQLClient)
            and self._database_name == __value._database_name
            and self._schema_name == __value._schema_name
        )

    def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
        """Return the schema-level object identifier built from this client's database, schema and the model."""
        name_parts = (
            self._database_name.identifier(),
            self._schema_name.identifier(),
            model_name.identifier(),
        )
        return identifier.get_schema_level_object_identifier(*name_parts)

    def show_models(
        self,
        *,
        model_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[row.Row]:
        """Run ``SHOW MODELS`` in this client's schema, optionally filtered by ``LIKE '<model name>'``.

        Args:
            model_name: When given, its resolved form is used as the LIKE pattern.
            statement_params: Forwarded to `collect`.

        Returns:
            The collected rows.
        """
        schema_fqn = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
        # NOTE(review): the resolved name is used verbatim as a LIKE pattern, so names containing
        # `_` or `%` may match more than one model - confirm whether that matters to callers.
        like_clause = f" LIKE '{model_name.resolved()}'" if model_name else ""
        query = self._session.sql(f"SHOW MODELS{like_clause} IN SCHEMA {schema_fqn}")
        return query.collect(statement_params=statement_params)

    def show_versions(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[row.Row]:
        """Run ``SHOW VERSIONS`` for a model, optionally filtered by ``LIKE '<version name>'``.

        Args:
            model_name: Model whose versions are listed.
            version_name: When given, its resolved form is used as the LIKE pattern.
            statement_params: Forwarded to `collect`.

        Returns:
            The collected rows.
        """
        like_clause = f" LIKE '{version_name.resolved()}'" if version_name else ""
        model_fqn = self.fully_qualified_model_name(model_name)
        query = self._session.sql(f"SHOW VERSIONS{like_clause} IN MODEL {model_fqn}")
        return query.collect(statement_params=statement_params)

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set the comment of a model with ``COMMENT ON MODEL``.

        Args:
            comment: Comment text, embedded between ``$$`` delimiters.
            model_name: Model to comment on.
            statement_params: Forwarded to `collect`.
        """
        model_fqn = self.fully_qualified_model_name(model_name)
        # NOTE(review): the comment is embedded as-is; text containing `$$` would end the literal early.
        statement = f"COMMENT ON MODEL {model_fqn} IS $${comment}$$"
        self._session.sql(statement).collect(statement_params=statement_params)

    def drop_model(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Drop a model with ``DROP MODEL``.

        Args:
            model_name: Model to drop.
            statement_params: Forwarded to `collect`.
        """
        model_fqn = self.fully_qualified_model_name(model_name)
        self._session.sql(f"DROP MODEL {model_fqn}").collect(statement_params=statement_params)
@@ -0,0 +1,213 @@
1
+ import json
2
+ import pathlib
3
+ import textwrap
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+ from urllib.parse import ParseResult
6
+
7
+ from snowflake.ml._internal.utils import identifier, sql_identifier
8
+ from snowflake.snowpark import dataframe, functions as F, session, types as spt
9
+ from snowflake.snowpark._internal import utils as snowpark_utils
10
+
11
+
12
def _normalize_url_for_sql(url: str) -> str:
    """Return ``url`` as a single-quoted SQL string literal.

    One pair of enclosing single quotes is removed first when present, so an
    already-quoted URL is not quoted twice. Every single quote that remains is
    backslash-escaped before the result is wrapped in single quotes.

    Args:
        url: The URL, with or without enclosing single quotes.

    Returns:
        The URL wrapped in exactly one pair of single quotes.
    """
    # Require two characters: a lone "'" is both the first and the last character and would
    # otherwise be treated as an empty quoted string, silently dropping the quote.
    if len(url) >= 2 and url.startswith("'") and url.endswith("'"):
        url = url[1:-1]
    escaped_url = url.replace("'", "\\'")
    return f"'{escaped_url}'"
17
+
18
+
19
class ModelVersionSQLClient:
    """Builds and runs the SQL statements that manage versions of a model.

    Model names are qualified with the database and schema given at construction time, and
    every statement is executed through the supplied Snowpark session.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Remember the session and the database/schema used to qualify model names."""
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name

    def __eq__(self, __value: object) -> bool:
        """Two clients are equal when their database and schema match; the session is ignored."""
        if isinstance(__value, ModelVersionSQLClient):
            return self._database_name == __value._database_name and self._schema_name == __value._schema_name
        return False

    @staticmethod
    def _string_literal(value: str) -> str:
        """Render ``value`` as a SQL string constant.

        Dollar quoting is used whenever it is safe because it needs no escaping. A value
        that contains ``$$`` or ends in ``$`` would close a dollar-quoted constant early, so
        such values are rendered as an escaped single-quoted constant instead.
        """
        if "$$" in value or value.endswith("$"):
            escaped_value = value.replace("\\", "\\\\").replace("'", "''")
            return f"'{escaped_value}'"
        return f"$${value}$$"

    def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
        """Build the schema-level identifier of ``model_name`` in this client's database and schema."""
        return identifier.get_schema_level_object_identifier(
            self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
        )

    def create_from_stage(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        stage_path: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``CREATE MODEL ... WITH VERSION ... FROM <stage_path>``.

        Args:
            model_name: Model to create.
            version_name: Name of its first version.
            stage_path: Inserted into the statement verbatim after ``FROM``.
            statement_params: Forwarded unchanged to ``collect``.
        """
        # NOTE(review): this attribute is not read anywhere in this class; kept because
        # readers outside this file cannot be ruled out.
        self._version_name = version_name
        create_sql = (
            f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
            f" FROM {stage_path}"
        )
        self._session.sql(create_sql).collect(statement_params=statement_params)

    # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
    def add_version_from_stage(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        stage_path: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... ADD VERSION ... FROM <stage_path>``.

        Args:
            model_name: Model that receives the new version.
            version_name: Name of the version to add.
            stage_path: Inserted into the statement verbatim after ``FROM``.
            statement_params: Forwarded unchanged to ``collect``.
        """
        # NOTE(review): this attribute is not read anywhere in this class; kept because
        # readers outside this file cannot be ruled out.
        self._version_name = version_name
        alter_sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
            f" FROM {stage_path}"
        )
        self._session.sql(alter_sql).collect(statement_params=statement_params)

    def set_default_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... SET DEFAULT_VERSION = <version_name>``."""
        alter_sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
            f"SET DEFAULT_VERSION = {version_name.identifier()}"
        )
        self._session.sql(alter_sql).collect(statement_params=statement_params)

    def get_default_version(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Return the ``"name"`` of the version flagged as default in ``SHOW VERSIONS``.

        The first matching row is indexed directly, so an empty result (no version with
        ``"is_default_version" = TRUE``) raises ``IndexError``.
        """
        # TODO: Replace SHOW with DESC when available.
        versions_df = self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}")
        default_rows = versions_df.filter('"is_default_version" = TRUE')[['"name"']].collect(
            statement_params=statement_params
        )
        default_version: str = default_rows[0][0]
        return default_version

    def get_file(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        file_path: pathlib.PurePosixPath,
        target_path: pathlib.Path,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> pathlib.Path:
        """Run a ``GET`` that copies one file of a model version into a local directory.

        Args:
            model_name: Model that owns the file.
            version_name: Version that owns the file; its resolved form is used in the path.
            file_path: Path of the file below the version.
            target_path: Local directory passed to ``GET``; made absolute first.
            statement_params: Forwarded unchanged to ``collect``.

        Returns:
            ``target_path`` joined with the final component of ``file_path``.
        """
        # Source: snow://model/<fully qualified model>/versions/<version>/<file_path>
        stage_location = pathlib.PurePosixPath(
            self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
        ).as_posix()
        stage_location_url = ParseResult(
            scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
        ).geturl()

        # Destination: file://<absolute target directory>
        local_location = target_path.absolute().as_posix()
        local_location_url = ParseResult(
            scheme="file", netloc="", path=local_location, params="", query="", fragment=""
        ).geturl()

        self._session.sql(
            f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}"
        ).collect(statement_params=statement_params)
        return target_path / file_path.name

    def set_comment(
        self,
        *,
        comment: str,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... MODIFY VERSION ... SET COMMENT=<comment>``.

        The comment is sent as a SQL string constant; see ``_string_literal`` for how text
        that cannot be dollar-quoted is handled.
        """
        comment_sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
            f"MODIFY VERSION {version_name.identifier()} SET COMMENT={self._string_literal(comment)}"
        )
        self._session.sql(comment_sql).collect(statement_params=statement_params)

    def invoke_method(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        method_name: sql_identifier.SqlIdentifier,
        input_df: dataframe.DataFrame,
        input_args: List[sql_identifier.SqlIdentifier],
        returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> dataframe.DataFrame:
        """Build a DataFrame that calls a model-version method on every row of ``input_df``.

        ``input_df`` is first saved to a randomly named temporary table in this client's
        database and schema. The returned DataFrame selects all of that table's columns plus
        one new column per entry of ``returns``.

        Args:
            model_name: Model to invoke.
            version_name: Version of the model to invoke.
            method_name: Method to call on that version.
            input_df: Input rows.
            input_args: Identifiers passed, comma separated, as the method's arguments.
            returns: ``(output_name, output_type, output_col_name)`` triples: the key looked
                up in the method's result, the type it is cast to, and the output column name.
            statement_params: Used when saving the temporary table and, when truthy, attached
                to the returned DataFrame.

        Returns:
            The DataFrame described above; building it runs only the temporary-table save.
        """
        tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
        intermediate_table_name = identifier.get_schema_level_object_identifier(
            self._database_name.identifier(),
            self._schema_name.identifier(),
            tmp_table_name,
        )
        input_df.write.save_as_table(  # type: ignore[call-overload]
            table_name=intermediate_table_name,
            mode="errorifexists",
            table_type="temporary",
            statement_params=statement_params,
        )

        # Column that temporarily holds the raw method result before it is split up.
        intermediate_obj_name = "TMP_RESULT"

        model_version_alias = "MODEL_VERSION_ALIAS"
        model_version_alias_sql = (
            f"WITH {model_version_alias} AS "
            f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
        )

        args_sql = ", ".join(input_args)

        sql = textwrap.dedent(
            f"""{model_version_alias_sql}
            SELECT *,
                {model_version_alias}!{method_name.identifier()}({args_sql}) AS {intermediate_obj_name}
            FROM {intermediate_table_name}"""
        )

        output_df = self._session.sql(sql)

        # Prepare the output: one typed column per declared return value.
        output_cols = []
        output_names = []

        for output_name, output_type, output_col_name in returns:
            output_cols.append(F.col(intermediate_obj_name)[output_name].astype(output_type))
            output_names.append(output_col_name)

        output_df = output_df.with_columns(
            col_names=output_names,
            values=output_cols,
        ).drop(intermediate_obj_name)

        if statement_params:
            output_df._statement_params = statement_params  # type: ignore[assignment]

        return output_df

    def set_metadata(
        self,
        metadata_dict: Dict[str, Any],
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... MODIFY VERSION ... SET METADATA=<json>``.

        ``metadata_dict`` is serialized with ``json.dumps`` and sent as a SQL string constant;
        see ``_string_literal`` for how text that cannot be dollar-quoted is handled.
        """
        json_metadata = json.dumps(metadata_dict)
        sql = (
            f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
            f" SET METADATA={self._string_literal(json_metadata)}"
        )
        self._session.sql(sql).collect(statement_params=statement_params)
@@ -0,0 +1,40 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from snowflake.ml._internal.utils import identifier, sql_identifier
4
+ from snowflake.snowpark import session
5
+
6
+
7
class StageSQLClient:
    """Runs stage-related SQL statements inside a fixed database and schema."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Remember the session and the database/schema used to qualify stage names."""
        self._database_name = database_name
        self._schema_name = schema_name
        self._session = session

    def __eq__(self, __value: object) -> bool:
        """Two clients are equal when their database and schema match; the session is ignored."""
        if isinstance(__value, StageSQLClient):
            return self._database_name == __value._database_name and self._schema_name == __value._schema_name
        return False

    def fully_qualified_stage_name(
        self,
        stage_name: sql_identifier.SqlIdentifier,
    ) -> str:
        """Build the schema-level identifier of ``stage_name`` in this client's database and schema."""
        database_part = self._database_name.identifier()
        schema_part = self._schema_name.identifier()
        stage_part = stage_name.identifier()
        return identifier.get_schema_level_object_identifier(database_part, schema_part, stage_part)

    def create_tmp_stage(
        self,
        *,
        stage_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Issue ``CREATE TEMPORARY STAGE`` for ``stage_name``.

        Args:
            stage_name: Stage to create.
            statement_params: Forwarded unchanged to ``collect``.
        """
        create_sql = f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}"
        self._session.sql(create_sql).collect(statement_params=statement_params)
@@ -4,7 +4,6 @@ import posixpath
4
4
  from string import Template
5
5
 
6
6
  import importlib_resources
7
- import yaml
8
7
 
9
8
  from snowflake import snowpark
10
9
  from snowflake.ml._internal import file_utils
@@ -180,7 +179,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
180
179
  assert self.artifact_stage_location.startswith("@")
181
180
  normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
182
181
  (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path)
183
- content = Template(spec_template).substitute(
182
+ content = Template(spec_template).safe_substitute(
184
183
  {
185
184
  "base_image": base_image,
186
185
  "container_name": constants.KANIKO_CONTAINER_NAME,
@@ -188,10 +187,10 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
188
187
  # Remove @ in the beginning, append "/" to denote root directory.
189
188
  "script_path": "/"
190
189
  + posixpath.normpath(identifier.remove_prefix(kaniko_shell_script_stage_location, "@")),
190
+ "mounted_token_path": constants.SPCS_MOUNTED_TOKEN_PATH,
191
191
  }
192
192
  )
193
- content_dict = yaml.safe_load(content)
194
- yaml.dump(content_dict, spec_file)
193
+ spec_file.write(content)
195
194
  spec_file.seek(0)
196
195
  logger.debug(f"Kaniko job spec file: \n\n {spec_file.read()}")
197
196