snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -4,7 +4,11 @@ import textwrap
4
4
  from typing import Any, Dict, List, Optional, Tuple
5
5
  from urllib.parse import ParseResult
6
6
 
7
- from snowflake.ml._internal.utils import identifier, sql_identifier
7
+ from snowflake.ml._internal.utils import (
8
+ identifier,
9
+ query_result_checker,
10
+ sql_identifier,
11
+ )
8
12
  from snowflake.snowpark import dataframe, functions as F, session, types as spt
9
13
  from snowflake.snowpark._internal import utils as snowpark_utils
10
14
 
@@ -46,11 +50,14 @@ class ModelVersionSQLClient:
46
50
  stage_path: str,
47
51
  statement_params: Optional[Dict[str, Any]] = None,
48
52
  ) -> None:
49
- self._version_name = version_name
50
- self._session.sql(
51
- f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
52
- f" FROM {stage_path}"
53
- ).collect(statement_params=statement_params)
53
+ query_result_checker.SqlResultValidator(
54
+ self._session,
55
+ (
56
+ f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
57
+ f" FROM {stage_path}"
58
+ ),
59
+ statement_params=statement_params,
60
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
54
61
 
55
62
  # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
56
63
  def add_version_from_stage(
@@ -61,11 +68,14 @@ class ModelVersionSQLClient:
61
68
  stage_path: str,
62
69
  statement_params: Optional[Dict[str, Any]] = None,
63
70
  ) -> None:
64
- self._version_name = version_name
65
- self._session.sql(
66
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
67
- f" FROM {stage_path}"
68
- ).collect(statement_params=statement_params)
71
+ query_result_checker.SqlResultValidator(
72
+ self._session,
73
+ (
74
+ f"ALTER MODEL {self.fully_qualified_model_name(model_name)} ADD VERSION {version_name.identifier()}"
75
+ f" FROM {stage_path}"
76
+ ),
77
+ statement_params=statement_params,
78
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
69
79
 
70
80
  def set_default_version(
71
81
  self,
@@ -74,24 +84,14 @@ class ModelVersionSQLClient:
74
84
  version_name: sql_identifier.SqlIdentifier,
75
85
  statement_params: Optional[Dict[str, Any]] = None,
76
86
  ) -> None:
77
- self._session.sql(
78
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
79
- f"SET DEFAULT_VERSION = {version_name.identifier()}"
80
- ).collect(statement_params=statement_params)
81
-
82
- def get_default_version(
83
- self,
84
- *,
85
- model_name: sql_identifier.SqlIdentifier,
86
- statement_params: Optional[Dict[str, Any]] = None,
87
- ) -> str:
88
- # TODO: Replace SHOW with DESC when available.
89
- default_version: str = (
90
- self._session.sql(f"SHOW VERSIONS IN MODEL {self.fully_qualified_model_name(model_name)}")
91
- .filter('"is_default_version" = TRUE')[['"name"']]
92
- .collect(statement_params=statement_params)[0][0]
93
- )
94
- return default_version
87
+ query_result_checker.SqlResultValidator(
88
+ self._session,
89
+ (
90
+ f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
91
+ f"SET DEFAULT_VERSION = {version_name.identifier()}"
92
+ ),
93
+ statement_params=statement_params,
94
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
95
95
 
96
96
  def get_file(
97
97
  self,
@@ -108,14 +108,14 @@ class ModelVersionSQLClient:
108
108
  stage_location_url = ParseResult(
109
109
  scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
110
110
  ).geturl()
111
- local_location = target_path.absolute().as_posix()
112
- local_location_url = ParseResult(
113
- scheme="file", netloc="", path=local_location, params="", query="", fragment=""
114
- ).geturl()
111
+ local_location = target_path.resolve().as_posix()
112
+ local_location_url = f"file://{local_location}"
115
113
 
116
- self._session.sql(
117
- f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}"
118
- ).collect(statement_params=statement_params)
114
+ query_result_checker.SqlResultValidator(
115
+ self._session,
116
+ f"GET {_normalize_url_for_sql(stage_location_url)} {_normalize_url_for_sql(local_location_url)}",
117
+ statement_params=statement_params,
118
+ ).has_dimensions(expected_rows=1).validate()
119
119
  return target_path / file_path.name
120
120
 
121
121
  def set_comment(
@@ -126,11 +126,14 @@ class ModelVersionSQLClient:
126
126
  version_name: sql_identifier.SqlIdentifier,
127
127
  statement_params: Optional[Dict[str, Any]] = None,
128
128
  ) -> None:
129
- comment_sql = (
130
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
131
- f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
132
- )
133
- self._session.sql(comment_sql).collect(statement_params=statement_params)
129
+ query_result_checker.SqlResultValidator(
130
+ self._session,
131
+ (
132
+ f"ALTER MODEL {self.fully_qualified_model_name(model_name)} "
133
+ f"MODIFY VERSION {version_name.identifier()} SET COMMENT=$${comment}$$"
134
+ ),
135
+ statement_params=statement_params,
136
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
134
137
 
135
138
  def invoke_method(
136
139
  self,
@@ -143,24 +146,29 @@ class ModelVersionSQLClient:
143
146
  returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
144
147
  statement_params: Optional[Dict[str, Any]] = None,
145
148
  ) -> dataframe.DataFrame:
146
- tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
147
- INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
148
- self._database_name.identifier(),
149
- self._schema_name.identifier(),
150
- tmp_table_name,
151
- )
152
- input_df.write.save_as_table( # type: ignore[call-overload]
153
- table_name=INTERMEDIATE_TABLE_NAME,
154
- mode="errorifexists",
155
- table_type="temporary",
156
- statement_params=statement_params,
157
- )
149
+ with_statements = []
150
+ if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
151
+ INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
152
+ with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
153
+ else:
154
+ tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
155
+ INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
156
+ self._database_name.identifier(),
157
+ self._schema_name.identifier(),
158
+ tmp_table_name,
159
+ )
160
+ input_df.write.save_as_table( # type: ignore[call-overload]
161
+ table_name=INTERMEDIATE_TABLE_NAME,
162
+ mode="errorifexists",
163
+ table_type="temporary",
164
+ statement_params=statement_params,
165
+ )
158
166
 
159
167
  INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
160
168
 
161
169
  module_version_alias = "MODEL_VERSION_ALIAS"
162
- model_version_alias_sql = (
163
- f"WITH {module_version_alias} AS "
170
+ with_statements.append(
171
+ f"{module_version_alias} AS "
164
172
  f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
165
173
  )
166
174
 
@@ -171,7 +179,7 @@ class ModelVersionSQLClient:
171
179
  args_sql = ", ".join(args_sql_list)
172
180
 
173
181
  sql = textwrap.dedent(
174
- f"""{model_version_alias_sql}
182
+ f"""WITH {','.join(with_statements)}
175
183
  SELECT *,
176
184
  {module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
177
185
  FROM {INTERMEDIATE_TABLE_NAME}"""
@@ -206,8 +214,11 @@ class ModelVersionSQLClient:
206
214
  statement_params: Optional[Dict[str, Any]] = None,
207
215
  ) -> None:
208
216
  json_metadata = json.dumps(metadata_dict)
209
- sql = (
210
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
211
- f" SET METADATA=$${json_metadata}$$"
212
- )
213
- self._session.sql(sql).collect(statement_params=statement_params)
217
+ query_result_checker.SqlResultValidator(
218
+ self._session,
219
+ (
220
+ f"ALTER MODEL {self.fully_qualified_model_name(model_name)} MODIFY VERSION {version_name.identifier()}"
221
+ f" SET METADATA=$${json_metadata}$$"
222
+ ),
223
+ statement_params=statement_params,
224
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,6 +1,10 @@
1
1
  from typing import Any, Dict, Optional
2
2
 
3
- from snowflake.ml._internal.utils import identifier, sql_identifier
3
+ from snowflake.ml._internal.utils import (
4
+ identifier,
5
+ query_result_checker,
6
+ sql_identifier,
7
+ )
4
8
  from snowflake.snowpark import session
5
9
 
6
10
 
@@ -35,6 +39,8 @@ class StageSQLClient:
35
39
  stage_name: sql_identifier.SqlIdentifier,
36
40
  statement_params: Optional[Dict[str, Any]] = None,
37
41
  ) -> None:
38
- self._session.sql(f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}").collect(
39
- statement_params=statement_params
40
- )
42
+ query_result_checker.SqlResultValidator(
43
+ self._session,
44
+ f"CREATE TEMPORARY STAGE {self.fully_qualified_stage_name(stage_name)}",
45
+ statement_params=statement_params,
46
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -0,0 +1,118 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from snowflake.ml._internal.utils import (
4
+ identifier,
5
+ query_result_checker,
6
+ sql_identifier,
7
+ )
8
+ from snowflake.snowpark import row, session
9
+
10
+
11
class ModuleTagSQLClient:
    """SQL client for setting, unsetting and reading tags on schema-level modules (e.g. models).

    Modules addressed by this client live in ``database_name.schema_name``. Tags are addressed by their own
    fully-qualified name and may live in a different database/schema.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Create a client bound to a session and a target database/schema.

        Args:
            session: Snowpark session used to run every statement.
            database_name: Database containing the modules this client operates on.
            schema_name: Schema containing the modules this client operates on.
        """
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name

    def __eq__(self, __value: object) -> bool:
        # Equality only considers the target database/schema; the session object is ignored.
        if not isinstance(__value, ModuleTagSQLClient):
            return False
        return self._database_name == __value._database_name and self._schema_name == __value._schema_name

    @staticmethod
    def _to_sql_string_literal(value: str) -> str:
        """Render ``value`` as a single-quoted SQL string literal that preserves its exact text.

        Dollar-quoted literals (``$$...$$``) cannot contain ``$$``, so interpolating arbitrary text into them
        can terminate the literal early. Backslashes are escaped before quotes so that the backslash added
        for each quote is not itself doubled.

        Args:
            value: Arbitrary text to embed in a SQL statement.

        Returns:
            The quoted, escaped literal, including the surrounding single quotes.
        """
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def fully_qualified_module_name(
        self,
        module_name: sql_identifier.SqlIdentifier,
    ) -> str:
        """Return ``module_name`` qualified with this client's database and schema."""
        return identifier.get_schema_level_object_identifier(
            self._database_name.identifier(), self._schema_name.identifier(), module_name.identifier()
        )

    def set_tag_on_model(
        self,
        model_name: sql_identifier.SqlIdentifier,
        *,
        tag_database_name: sql_identifier.SqlIdentifier,
        tag_schema_name: sql_identifier.SqlIdentifier,
        tag_name: sql_identifier.SqlIdentifier,
        tag_value: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... SET TAG`` and validate a 1x1 result.

        Args:
            model_name: Model (in this client's database/schema) to tag.
            tag_database_name: Database that contains the tag.
            tag_schema_name: Schema that contains the tag.
            tag_name: Name of the tag.
            tag_value: Value to assign. It is escaped, so any text, including ``$$`` or quotes, is stored verbatim.
            statement_params: Optional statement parameters forwarded with the query.
        """
        fq_model_name = self.fully_qualified_module_name(model_name)
        fq_tag_name = identifier.get_schema_level_object_identifier(
            tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
        )
        query_result_checker.SqlResultValidator(
            self._session,
            f"ALTER MODEL {fq_model_name} SET TAG {fq_tag_name} = {self._to_sql_string_literal(tag_value)}",
            statement_params=statement_params,
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def unset_tag_on_model(
        self,
        model_name: sql_identifier.SqlIdentifier,
        *,
        tag_database_name: sql_identifier.SqlIdentifier,
        tag_schema_name: sql_identifier.SqlIdentifier,
        tag_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL ... UNSET TAG`` and validate a 1x1 result.

        Args:
            model_name: Model (in this client's database/schema) to untag.
            tag_database_name: Database that contains the tag.
            tag_schema_name: Schema that contains the tag.
            tag_name: Name of the tag.
            statement_params: Optional statement parameters forwarded with the query.
        """
        fq_model_name = self.fully_qualified_module_name(model_name)
        fq_tag_name = identifier.get_schema_level_object_identifier(
            tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
        )
        query_result_checker.SqlResultValidator(
            self._session,
            f"ALTER MODEL {fq_model_name} UNSET TAG {fq_tag_name}",
            statement_params=statement_params,
        ).has_dimensions(expected_rows=1, expected_cols=1).validate()

    def get_tag_value(
        self,
        module_name: sql_identifier.SqlIdentifier,
        *,
        tag_database_name: sql_identifier.SqlIdentifier,
        tag_schema_name: sql_identifier.SqlIdentifier,
        tag_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> row.Row:
        """Query ``SYSTEM$GET_TAG`` for one tag on a module.

        Args:
            module_name: Module (in this client's database/schema) to inspect.
            tag_database_name: Database that contains the tag.
            tag_schema_name: Schema that contains the tag.
            tag_name: Name of the tag.
            statement_params: Optional statement parameters forwarded with the query.

        Returns:
            The single result row, which carries a ``TAG_VALUE`` column.
        """
        fq_module_name = self.fully_qualified_module_name(module_name)
        fq_tag_name = identifier.get_schema_level_object_identifier(
            tag_database_name.identifier(), tag_schema_name.identifier(), tag_name.identifier()
        )
        tag_literal = self._to_sql_string_literal(fq_tag_name)
        module_literal = self._to_sql_string_literal(fq_module_name)
        return (
            query_result_checker.SqlResultValidator(
                self._session,
                f"SELECT SYSTEM$GET_TAG({tag_literal}, {module_literal}, 'MODULE') AS TAG_VALUE",
                statement_params=statement_params,
            )
            .has_dimensions(expected_rows=1, expected_cols=1)
            .has_column("TAG_VALUE")
            .validate()[0]
        )

    def get_tag_list(
        self,
        module_name: sql_identifier.SqlIdentifier,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[row.Row]:
        """List the tags referenced by a module via ``INFORMATION_SCHEMA.TAG_REFERENCES``.

        Args:
            module_name: Module (in this client's database/schema) to inspect.
            statement_params: Optional statement parameters forwarded with the query.

        Returns:
            Rows with ``TAG_DATABASE``, ``TAG_SCHEMA``, ``TAG_NAME`` and ``TAG_VALUE`` columns; may be empty.
        """
        fq_module_name = self.fully_qualified_module_name(module_name)
        module_literal = self._to_sql_string_literal(fq_module_name)
        return (
            query_result_checker.SqlResultValidator(
                self._session,
                f"""SELECT TAG_DATABASE, TAG_SCHEMA, TAG_NAME, TAG_VALUE
FROM TABLE({self._database_name.identifier()}.INFORMATION_SCHEMA.TAG_REFERENCES({module_literal}, 'MODULE'))""",
                statement_params=statement_params,
            )
            .has_column("TAG_DATABASE", allow_empty=True)
            .has_column("TAG_SCHEMA", allow_empty=True)
            .has_column("TAG_NAME", allow_empty=True)
            .has_column("TAG_VALUE", allow_empty=True)
            .validate()
        )
@@ -9,11 +9,11 @@ from enum import Enum
9
9
  from typing import List
10
10
 
11
11
  from snowflake import snowpark
12
+ from snowflake.ml._internal.container_services.image_registry import credential
12
13
  from snowflake.ml._internal.exceptions import (
13
14
  error_codes,
14
15
  exceptions as snowml_exceptions,
15
16
  )
16
- from snowflake.ml._internal.utils import spcs_image_registry
17
17
  from snowflake.ml.model._deploy_client.image_builds import base_image_builder
18
18
 
19
19
  logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class ClientImageBuilder(base_image_builder.ImageBuilder):
106
106
  self._run_docker_commands(commands)
107
107
 
108
108
  self.validate_docker_client_env()
109
- with spcs_image_registry.generate_image_registry_credential(
109
+ with credential.generate_image_registry_credential(
110
110
  self.session
111
111
  ) as registry_cred, tempfile.TemporaryDirectory() as docker_config_dir:
112
112
  try:
@@ -2,7 +2,6 @@ import os
2
2
  import posixpath
3
3
  import shutil
4
4
  import string
5
- from abc import ABC
6
5
  from typing import Optional
7
6
 
8
7
  import importlib_resources
@@ -15,7 +14,7 @@ from snowflake.ml.model._packager.model_meta import model_meta
15
14
  from snowflake.snowpark import FileOperation, Session
16
15
 
17
16
 
18
- class DockerContext(ABC):
17
+ class DockerContext:
19
18
  """
20
19
  Constructs the Docker context directory required for image building.
21
20
  """
@@ -53,12 +52,13 @@ class DockerContext(ABC):
53
52
 
54
53
  def _copy_entrypoint_script_to_docker_context(self) -> None:
55
54
  """Copy gunicorn_run.sh entrypoint to docker context directory."""
56
- with importlib_resources.as_file(
57
- importlib_resources.files(image_builds).joinpath( # type: ignore[no-untyped-call]
58
- constants.ENTRYPOINT_SCRIPT
59
- )
60
- ) as path:
61
- shutil.copy(path, os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT))
55
+ script_path = importlib_resources.files(image_builds).joinpath( # type: ignore[no-untyped-call]
56
+ constants.ENTRYPOINT_SCRIPT
57
+ )
58
+ target_path = os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT)
59
+
60
+ with open(script_path, encoding="utf-8") as source_file, file_utils.open_file(target_path, "w") as target_file:
61
+ target_file.write(source_file.read())
62
62
 
63
63
  def _copy_model_env_dependency_to_docker_context(self) -> None:
64
64
  """
@@ -105,6 +105,8 @@ def _run_setup() -> None:
105
105
 
106
106
  # TODO (Server-side Model Rollout):
107
107
  # Keep try block only
108
+ # SPCS spec will convert all environment variables as strings.
109
+ use_gpu = os.environ.get("SNOWML_USE_GPU", "False").lower() == "true"
108
110
  try:
109
111
  from snowflake.ml.model._packager import model_packager
110
112
 
@@ -112,9 +114,7 @@ def _run_setup() -> None:
112
114
  pk.load(
113
115
  as_custom_model=True,
114
116
  meta_only=False,
115
- options=model_types.ModelLoadOption(
116
- {"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))}
117
- ),
117
+ options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
118
118
  )
119
119
  _LOADED_MODEL = pk.model
120
120
  _LOADED_META = pk.meta
@@ -132,9 +132,7 @@ def _run_setup() -> None:
132
132
  _LOADED_MODEL, meta_LOADED_META = model_api._load(
133
133
  local_dir_path=extracted_dir,
134
134
  as_custom_model=True,
135
- options=model_types.ModelLoadOption(
136
- {"use_gpu": cast(bool, os.environ.get("SNOWML_USE_GPU", False))}
137
- ),
135
+ options=model_types.ModelLoadOption({"use_gpu": use_gpu}),
138
136
  )
139
137
  _MODEL_LOADING_STATE = _ModelLoadingState.SUCCEEDED
140
138
  logger.info("Successfully loaded model into memory")
@@ -7,6 +7,9 @@ import importlib_resources
7
7
 
8
8
  from snowflake import snowpark
9
9
  from snowflake.ml._internal import file_utils
10
+ from snowflake.ml._internal.container_services.image_registry import (
11
+ registry_client as image_registry_client,
12
+ )
10
13
  from snowflake.ml._internal.exceptions import (
11
14
  error_codes,
12
15
  exceptions as snowml_exceptions,
@@ -14,11 +17,7 @@ from snowflake.ml._internal.exceptions import (
14
17
  from snowflake.ml._internal.utils import identifier
15
18
  from snowflake.ml.model._deploy_client import image_builds
16
19
  from snowflake.ml.model._deploy_client.image_builds import base_image_builder
17
- from snowflake.ml.model._deploy_client.utils import (
18
- constants,
19
- image_registry_client,
20
- snowservice_client,
21
- )
20
+ from snowflake.ml.model._deploy_client.utils import constants, snowservice_client
22
21
 
23
22
  logger = logging.getLogger(__name__)
24
23
 
@@ -117,7 +116,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
117
116
 
118
117
  kaniko_shell_file = os.path.join(self.context_dir, constants.KANIKO_SHELL_SCRIPT_NAME)
119
118
 
120
- with open(kaniko_shell_file, "w+", encoding="utf-8") as script_file:
119
+ with file_utils.open_file(kaniko_shell_file, "w+") as script_file:
121
120
  normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
122
121
  params = {
123
122
  # Remove @ in the beginning, append "/" to denote root directory.
@@ -175,7 +174,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
175
174
  os.path.dirname(self.context_dir), f"{constants.IMAGE_BUILD_JOB_SPEC_TEMPLATE}.yaml"
176
175
  )
177
176
 
178
- with open(spec_file_path, "w+", encoding="utf-8") as spec_file:
177
+ with file_utils.open_file(spec_file_path, "w+") as spec_file:
179
178
  assert self.artifact_stage_location.startswith("@")
180
179
  normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
181
180
  (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(normed_artifact_stage_path)
@@ -14,6 +14,9 @@ from packaging import requirements
14
14
  from typing_extensions import Unpack
15
15
 
16
16
  from snowflake.ml._internal import env_utils, file_utils
17
+ from snowflake.ml._internal.container_services.image_registry import (
18
+ registry_client as image_registry_client,
19
+ )
17
20
  from snowflake.ml._internal.exceptions import (
18
21
  error_codes,
19
22
  exceptions as snowml_exceptions,
@@ -32,11 +35,7 @@ from snowflake.ml.model._deploy_client.image_builds import (
32
35
  server_image_builder,
33
36
  )
34
37
  from snowflake.ml.model._deploy_client.snowservice import deploy_options, instance_types
35
- from snowflake.ml.model._deploy_client.utils import (
36
- constants,
37
- image_registry_client,
38
- snowservice_client,
39
- )
38
+ from snowflake.ml.model._deploy_client.utils import constants, snowservice_client
40
39
  from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
41
40
  from snowflake.snowpark import Session
42
41
 
@@ -1,2 +1,10 @@
1
1
  # Snowpark Container Service GPU instance type and corresponding GPU counts.
2
- INSTANCE_TYPE_TO_GPU_COUNT = {"GPU_3": 1, "GPU_5": 1, "GPU_7": 4, "GPU_10": 8}
2
+ INSTANCE_TYPE_TO_GPU_COUNT = {
3
+ "GPU_3": 1,
4
+ "GPU_5": 1,
5
+ "GPU_7": 4,
6
+ "GPU_10": 8,
7
+ "GPU_NV_S": 1,
8
+ "GPU_NV_M": 4,
9
+ "GPU_NV_L": 8,
10
+ }
@@ -2,6 +2,7 @@ import copy
2
2
  import logging
3
3
  import posixpath
4
4
  import tempfile
5
+ import textwrap
5
6
  from types import ModuleType
6
7
  from typing import IO, List, Optional, Tuple, TypedDict, Union
7
8
 
@@ -154,7 +155,7 @@ def _get_model_final_packages(
154
155
  Returns:
155
156
  List of final packages string that is accepted by Snowpark register UDF call.
156
157
  """
157
- final_packages = None
158
+
158
159
  if (
159
160
  any(channel.lower() not in [env_utils.DEFAULT_CHANNEL_NAME] for channel in meta.env._conda_dependencies.keys())
160
161
  or meta.env.pip_requirements
@@ -173,21 +174,29 @@ def _get_model_final_packages(
173
174
  else:
174
175
  required_packages = meta.env._conda_dependencies[env_utils.DEFAULT_CHANNEL_NAME]
175
176
 
176
- final_packages = env_utils.validate_requirements_in_information_schema(
177
+ package_availability_dict = env_utils.get_matched_package_versions_in_information_schema(
177
178
  session, required_packages, python_version=meta.env.python_version
178
179
  )
179
-
180
- if final_packages is None:
180
+ no_version_available_packages = [
181
+ req_name for req_name, ver_list in package_availability_dict.items() if len(ver_list) < 1
182
+ ]
183
+ unavailable_packages = [req.name for req in required_packages if req.name not in package_availability_dict]
184
+ if no_version_available_packages or unavailable_packages:
181
185
  relax_version_info_str = "" if relax_version else "Try to set relax_version as True in the options. "
186
+ required_package_str = " ".join(map(lambda x: f'"{x}"', required_packages))
182
187
  raise snowml_exceptions.SnowflakeMLException(
183
188
  error_code=error_codes.DEPENDENCY_VERSION_ERROR,
184
189
  original_exception=RuntimeError(
185
- "The model's dependencies are not available in Snowflake Anaconda Channel. "
186
- + relax_version_info_str
187
- + "Required packages are:\n"
188
- + " ".join(map(lambda x: f'"{x}"', required_packages))
189
- + "\n Required Python version is: "
190
- + meta.env.python_version
190
+ textwrap.dedent(
191
+ f"""
192
+ The model's dependencies are not available in Snowflake Anaconda Channel. {relax_version_info_str}
193
+ Required packages are: {required_package_str}
194
+ Required Python version is: {meta.env.python_version}
195
+ Packages that are not available are: {unavailable_packages}
196
+ Packages that cannot meet your requirements are: {no_version_available_packages}
197
+ Package availability information of those you requested is: {package_availability_dict}
198
+ """
199
+ ),
191
200
  ),
192
201
  )
193
- return final_packages
202
+ return list(sorted(map(str, required_packages)))
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import pathlib
3
- from typing import List, Optional, cast
3
+ from typing import Any, Dict, List, Optional, cast
4
4
 
5
5
  import yaml
6
6
 
@@ -83,7 +83,11 @@ class ModelManifest:
83
83
  ],
84
84
  )
85
85
 
86
+ manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
87
+
86
88
  with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
89
+ # Anchors are not supported in the server, avoid that.
90
+ yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
87
91
  yaml.safe_dump(manifest_dict, f)
88
92
 
89
93
  def load(self) -> model_manifest_schema.ModelManifestDict:
@@ -99,3 +103,43 @@ class ModelManifest:
99
103
  res = cast(model_manifest_schema.ModelManifestDict, raw_input)
100
104
 
101
105
  return res
106
+
107
+ def generate_user_data_with_client_data(self, model_meta: model_meta_api.ModelMetadata) -> Dict[str, Any]:
108
+ client_data = model_manifest_schema.SnowparkMLDataDict(
109
+ schema_version=model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION,
110
+ functions=[
111
+ model_manifest_schema.ModelFunctionInfoDict(
112
+ name=method.method_name.identifier(),
113
+ target_method=method.target_method,
114
+ signature=model_meta.signatures[method.target_method].to_dict(),
115
+ )
116
+ for method in self.methods
117
+ ],
118
+ )
119
+ return {model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME: client_data}
120
+
121
+ @staticmethod
122
+ def parse_client_data_from_user_data(raw_user_data: Dict[str, Any]) -> model_manifest_schema.SnowparkMLDataDict:
123
+ raw_client_data = raw_user_data.get(model_manifest_schema.MANIFEST_CLIENT_DATA_KEY_NAME, {})
124
+ if not isinstance(raw_client_data, dict) or "schema_version" not in raw_client_data:
125
+ raise ValueError(f"Ill-formatted client data {raw_client_data} in user data found.")
126
+ loaded_client_data_schema_version = raw_client_data["schema_version"]
127
+ if (
128
+ not isinstance(loaded_client_data_schema_version, str)
129
+ or loaded_client_data_schema_version != model_manifest_schema.MANIFEST_CLIENT_DATA_SCHEMA_VERSION
130
+ ):
131
+ raise ValueError(f"Unsupported client data schema version {loaded_client_data_schema_version} confronted.")
132
+
133
+ return_functions_info: List[model_manifest_schema.ModelFunctionInfoDict] = []
134
+ loaded_functions_info = raw_client_data.get("functions", [])
135
+ for func in loaded_functions_info:
136
+ fi = model_manifest_schema.ModelFunctionInfoDict(
137
+ name=func["name"],
138
+ target_method=func["target_method"],
139
+ signature=func["signature"],
140
+ )
141
+ return_functions_info.append(fi)
142
+
143
+ return model_manifest_schema.SnowparkMLDataDict(
144
+ schema_version=loaded_client_data_schema_version, functions=return_functions_info
145
+ )