snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/ml/_internal/telemetry.py +19 -0
  2. snowflake/ml/model/_client/ops/model_ops.py +16 -38
  3. snowflake/ml/model/_client/sql/model.py +1 -7
  4. snowflake/ml/model/_client/sql/model_version.py +20 -15
  5. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
  6. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
  7. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  8. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
  9. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  10. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  11. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  12. snowflake/ml/model/type_hints.py +3 -0
  13. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +63 -95
  14. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  15. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +16 -0
  16. snowflake/ml/modeling/cluster/affinity_propagation.py +16 -0
  17. snowflake/ml/modeling/cluster/agglomerative_clustering.py +16 -0
  18. snowflake/ml/modeling/cluster/birch.py +16 -0
  19. snowflake/ml/modeling/cluster/bisecting_k_means.py +16 -0
  20. snowflake/ml/modeling/cluster/dbscan.py +16 -0
  21. snowflake/ml/modeling/cluster/feature_agglomeration.py +16 -0
  22. snowflake/ml/modeling/cluster/k_means.py +16 -0
  23. snowflake/ml/modeling/cluster/mean_shift.py +16 -0
  24. snowflake/ml/modeling/cluster/mini_batch_k_means.py +16 -0
  25. snowflake/ml/modeling/cluster/optics.py +16 -0
  26. snowflake/ml/modeling/cluster/spectral_biclustering.py +16 -0
  27. snowflake/ml/modeling/cluster/spectral_clustering.py +16 -0
  28. snowflake/ml/modeling/cluster/spectral_coclustering.py +16 -0
  29. snowflake/ml/modeling/compose/column_transformer.py +16 -0
  30. snowflake/ml/modeling/compose/transformed_target_regressor.py +16 -0
  31. snowflake/ml/modeling/covariance/elliptic_envelope.py +16 -0
  32. snowflake/ml/modeling/covariance/empirical_covariance.py +16 -0
  33. snowflake/ml/modeling/covariance/graphical_lasso.py +16 -0
  34. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +16 -0
  35. snowflake/ml/modeling/covariance/ledoit_wolf.py +16 -0
  36. snowflake/ml/modeling/covariance/min_cov_det.py +16 -0
  37. snowflake/ml/modeling/covariance/oas.py +16 -0
  38. snowflake/ml/modeling/covariance/shrunk_covariance.py +16 -0
  39. snowflake/ml/modeling/decomposition/dictionary_learning.py +16 -0
  40. snowflake/ml/modeling/decomposition/factor_analysis.py +16 -0
  41. snowflake/ml/modeling/decomposition/fast_ica.py +16 -0
  42. snowflake/ml/modeling/decomposition/incremental_pca.py +16 -0
  43. snowflake/ml/modeling/decomposition/kernel_pca.py +16 -0
  44. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +16 -0
  45. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +16 -0
  46. snowflake/ml/modeling/decomposition/pca.py +16 -0
  47. snowflake/ml/modeling/decomposition/sparse_pca.py +16 -0
  48. snowflake/ml/modeling/decomposition/truncated_svd.py +16 -0
  49. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +16 -0
  50. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +16 -0
  51. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +16 -0
  52. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +16 -0
  53. snowflake/ml/modeling/ensemble/bagging_classifier.py +16 -0
  54. snowflake/ml/modeling/ensemble/bagging_regressor.py +16 -0
  55. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +16 -0
  56. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +16 -0
  57. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +16 -0
  58. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +16 -0
  59. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +16 -0
  60. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +16 -0
  61. snowflake/ml/modeling/ensemble/isolation_forest.py +16 -0
  62. snowflake/ml/modeling/ensemble/random_forest_classifier.py +16 -0
  63. snowflake/ml/modeling/ensemble/random_forest_regressor.py +16 -0
  64. snowflake/ml/modeling/ensemble/stacking_regressor.py +16 -0
  65. snowflake/ml/modeling/ensemble/voting_classifier.py +16 -0
  66. snowflake/ml/modeling/ensemble/voting_regressor.py +16 -0
  67. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +16 -0
  68. snowflake/ml/modeling/feature_selection/select_fdr.py +16 -0
  69. snowflake/ml/modeling/feature_selection/select_fpr.py +16 -0
  70. snowflake/ml/modeling/feature_selection/select_fwe.py +16 -0
  71. snowflake/ml/modeling/feature_selection/select_k_best.py +16 -0
  72. snowflake/ml/modeling/feature_selection/select_percentile.py +16 -0
  73. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +16 -0
  74. snowflake/ml/modeling/feature_selection/variance_threshold.py +16 -0
  75. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +16 -0
  76. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +16 -0
  77. snowflake/ml/modeling/impute/iterative_imputer.py +16 -0
  78. snowflake/ml/modeling/impute/knn_imputer.py +16 -0
  79. snowflake/ml/modeling/impute/missing_indicator.py +16 -0
  80. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +16 -0
  81. snowflake/ml/modeling/kernel_approximation/nystroem.py +16 -0
  82. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +16 -0
  83. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +16 -0
  84. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +16 -0
  85. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +16 -0
  86. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +16 -0
  87. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +16 -0
  88. snowflake/ml/modeling/linear_model/ard_regression.py +16 -0
  89. snowflake/ml/modeling/linear_model/bayesian_ridge.py +16 -0
  90. snowflake/ml/modeling/linear_model/elastic_net.py +16 -0
  91. snowflake/ml/modeling/linear_model/elastic_net_cv.py +16 -0
  92. snowflake/ml/modeling/linear_model/gamma_regressor.py +16 -0
  93. snowflake/ml/modeling/linear_model/huber_regressor.py +16 -0
  94. snowflake/ml/modeling/linear_model/lars.py +16 -0
  95. snowflake/ml/modeling/linear_model/lars_cv.py +16 -0
  96. snowflake/ml/modeling/linear_model/lasso.py +16 -0
  97. snowflake/ml/modeling/linear_model/lasso_cv.py +16 -0
  98. snowflake/ml/modeling/linear_model/lasso_lars.py +16 -0
  99. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +16 -0
  100. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +16 -0
  101. snowflake/ml/modeling/linear_model/linear_regression.py +16 -0
  102. snowflake/ml/modeling/linear_model/logistic_regression.py +16 -0
  103. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +16 -0
  104. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +16 -0
  105. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +16 -0
  106. snowflake/ml/modeling/linear_model/multi_task_lasso.py +16 -0
  107. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +16 -0
  108. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +16 -0
  109. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +16 -0
  110. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +16 -0
  111. snowflake/ml/modeling/linear_model/perceptron.py +16 -0
  112. snowflake/ml/modeling/linear_model/poisson_regressor.py +16 -0
  113. snowflake/ml/modeling/linear_model/ransac_regressor.py +16 -0
  114. snowflake/ml/modeling/linear_model/ridge.py +16 -0
  115. snowflake/ml/modeling/linear_model/ridge_classifier.py +16 -0
  116. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +16 -0
  117. snowflake/ml/modeling/linear_model/ridge_cv.py +16 -0
  118. snowflake/ml/modeling/linear_model/sgd_classifier.py +16 -0
  119. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +16 -0
  120. snowflake/ml/modeling/linear_model/sgd_regressor.py +16 -0
  121. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +16 -0
  122. snowflake/ml/modeling/linear_model/tweedie_regressor.py +16 -0
  123. snowflake/ml/modeling/manifold/isomap.py +16 -0
  124. snowflake/ml/modeling/manifold/mds.py +16 -0
  125. snowflake/ml/modeling/manifold/spectral_embedding.py +16 -0
  126. snowflake/ml/modeling/manifold/tsne.py +16 -0
  127. snowflake/ml/modeling/metrics/classification.py +5 -6
  128. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  129. snowflake/ml/modeling/metrics/ranking.py +7 -3
  130. snowflake/ml/modeling/metrics/regression.py +6 -3
  131. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +16 -0
  132. snowflake/ml/modeling/mixture/gaussian_mixture.py +16 -0
  133. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +16 -0
  134. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +16 -0
  135. snowflake/ml/modeling/multiclass/output_code_classifier.py +16 -0
  136. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +16 -0
  137. snowflake/ml/modeling/naive_bayes/categorical_nb.py +16 -0
  138. snowflake/ml/modeling/naive_bayes/complement_nb.py +16 -0
  139. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +16 -0
  140. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +16 -0
  141. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +16 -0
  142. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +16 -0
  143. snowflake/ml/modeling/neighbors/kernel_density.py +16 -0
  144. snowflake/ml/modeling/neighbors/local_outlier_factor.py +16 -0
  145. snowflake/ml/modeling/neighbors/nearest_centroid.py +16 -0
  146. snowflake/ml/modeling/neighbors/nearest_neighbors.py +16 -0
  147. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +16 -0
  148. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +16 -0
  149. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +16 -0
  150. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +16 -0
  151. snowflake/ml/modeling/neural_network/mlp_classifier.py +16 -0
  152. snowflake/ml/modeling/neural_network/mlp_regressor.py +16 -0
  153. snowflake/ml/modeling/preprocessing/polynomial_features.py +16 -0
  154. snowflake/ml/modeling/semi_supervised/label_propagation.py +16 -0
  155. snowflake/ml/modeling/semi_supervised/label_spreading.py +16 -0
  156. snowflake/ml/modeling/svm/linear_svc.py +16 -0
  157. snowflake/ml/modeling/svm/linear_svr.py +16 -0
  158. snowflake/ml/modeling/svm/nu_svc.py +16 -0
  159. snowflake/ml/modeling/svm/nu_svr.py +16 -0
  160. snowflake/ml/modeling/svm/svc.py +16 -0
  161. snowflake/ml/modeling/svm/svr.py +16 -0
  162. snowflake/ml/modeling/tree/decision_tree_classifier.py +16 -0
  163. snowflake/ml/modeling/tree/decision_tree_regressor.py +16 -0
  164. snowflake/ml/modeling/tree/extra_tree_classifier.py +16 -0
  165. snowflake/ml/modeling/tree/extra_tree_regressor.py +16 -0
  166. snowflake/ml/modeling/xgboost/xgb_classifier.py +16 -0
  167. snowflake/ml/modeling/xgboost/xgb_regressor.py +16 -0
  168. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +16 -0
  169. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +16 -0
  170. snowflake/ml/registry/registry.py +2 -0
  171. snowflake/ml/version.py +1 -1
  172. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  173. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +261 -50
  174. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/RECORD +189 -186
  175. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  176. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
@@ -584,3 +584,22 @@ class _SourceTelemetryClient:
584
584
  """Send the telemetry data batch immediately."""
585
585
  if self._telemetry:
586
586
  self._telemetry.send_batch()
587
+
588
+
589
+ def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: Dict[str, Any]) -> Dict[str, Any]:
590
+ """
591
+ Get statement_params keyword argument for sproc call.
592
+
593
+ Args:
594
+ sproc: sproc function
595
+ statement_params: dictionary to be passed as statement params, if possible
596
+
597
+ Returns:
598
+ Keyword arguments dict
599
+ """
600
+ sproc_argspec = inspect.getfullargspec(sproc)
601
+ kwargs = {}
602
+ if "statement_params" in sproc_argspec.args:
603
+ kwargs["statement_params"] = statement_params
604
+
605
+ return kwargs
@@ -4,9 +4,8 @@ import tempfile
4
4
  from typing import Any, Dict, List, Optional, Union, cast
5
5
 
6
6
  import yaml
7
- from packaging import version
8
7
 
9
- from snowflake.ml._internal.utils import identifier, snowflake_env, sql_identifier
8
+ from snowflake.ml._internal.utils import identifier, sql_identifier
10
9
  from snowflake.ml.model import model_signature, type_hints
11
10
  from snowflake.ml.model._client.ops import metadata_ops
12
11
  from snowflake.ml.model._client.sql import (
@@ -25,8 +24,6 @@ from snowflake.ml.model._signatures import snowpark_handler
25
24
  from snowflake.snowpark import dataframe, row, session
26
25
  from snowflake.snowpark._internal import utils as snowpark_utils
27
26
 
28
- _TAG_ON_MODEL_AVAILABLE_VERSION = version.parse("8.2.0")
29
-
30
27
 
31
28
  class ModelOperator:
32
29
  def __init__(
@@ -296,21 +293,14 @@ class ModelOperator:
296
293
  tag_value: str,
297
294
  statement_params: Optional[Dict[str, Any]] = None,
298
295
  ) -> None:
299
- sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
300
- if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION:
301
- self._tag_client.set_tag_on_model(
302
- model_name=model_name,
303
- tag_database_name=tag_database_name,
304
- tag_schema_name=tag_schema_name,
305
- tag_name=tag_name,
306
- tag_value=tag_value,
307
- statement_params=statement_params,
308
- )
309
- else:
310
- raise NotImplementedError(
311
- f"`set_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
312
- f" currently is {sf_version}"
313
- )
296
+ self._tag_client.set_tag_on_model(
297
+ model_name=model_name,
298
+ tag_database_name=tag_database_name,
299
+ tag_schema_name=tag_schema_name,
300
+ tag_name=tag_name,
301
+ tag_value=tag_value,
302
+ statement_params=statement_params,
303
+ )
314
304
 
315
305
  def unset_tag(
316
306
  self,
@@ -321,20 +311,13 @@ class ModelOperator:
321
311
  tag_name: sql_identifier.SqlIdentifier,
322
312
  statement_params: Optional[Dict[str, Any]] = None,
323
313
  ) -> None:
324
- sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
325
- if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION:
326
- self._tag_client.unset_tag_on_model(
327
- model_name=model_name,
328
- tag_database_name=tag_database_name,
329
- tag_schema_name=tag_schema_name,
330
- tag_name=tag_name,
331
- statement_params=statement_params,
332
- )
333
- else:
334
- raise NotImplementedError(
335
- f"`unset_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
336
- f" currently is {sf_version}"
337
- )
314
+ self._tag_client.unset_tag_on_model(
315
+ model_name=model_name,
316
+ tag_database_name=tag_database_name,
317
+ tag_schema_name=tag_schema_name,
318
+ tag_name=tag_name,
319
+ statement_params=statement_params,
320
+ )
338
321
 
339
322
  def get_model_version_manifest(
340
323
  self,
@@ -382,11 +365,6 @@ class ModelOperator:
382
365
  version_name: sql_identifier.SqlIdentifier,
383
366
  statement_params: Optional[Dict[str, Any]] = None,
384
367
  ) -> model_manifest_schema.SnowparkMLDataDict:
385
- if (
386
- snowflake_env.get_current_snowflake_version(self._session)
387
- < model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
388
- ):
389
- raise NotImplementedError("User_data has not been supported yet.")
390
368
  raw_user_data_json_string = self._model_client.show_versions(
391
369
  model_name=model_name,
392
370
  version_name=version_name,
@@ -3,10 +3,8 @@ from typing import Any, Dict, List, Optional
3
3
  from snowflake.ml._internal.utils import (
4
4
  identifier,
5
5
  query_result_checker,
6
- snowflake_env,
7
6
  sql_identifier,
8
7
  )
9
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
10
8
  from snowflake.snowpark import row, session
11
9
 
12
10
 
@@ -89,12 +87,8 @@ class ModelSQLClient:
89
87
  .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
90
88
  .has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True)
91
89
  .has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True)
90
+ .has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
92
91
  )
93
- if (
94
- snowflake_env.get_current_snowflake_version(self._session)
95
- >= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
96
- ):
97
- res = res.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
98
92
  if validate_result and version_name:
99
93
  res = res.has_dimensions(expected_rows=1)
100
94
 
@@ -146,24 +146,29 @@ class ModelVersionSQLClient:
146
146
  returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
147
147
  statement_params: Optional[Dict[str, Any]] = None,
148
148
  ) -> dataframe.DataFrame:
149
- tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
150
- INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
151
- self._database_name.identifier(),
152
- self._schema_name.identifier(),
153
- tmp_table_name,
154
- )
155
- input_df.write.save_as_table( # type: ignore[call-overload]
156
- table_name=INTERMEDIATE_TABLE_NAME,
157
- mode="errorifexists",
158
- table_type="temporary",
159
- statement_params=statement_params,
160
- )
149
+ with_statements = []
150
+ if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
151
+ INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
152
+ with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
153
+ else:
154
+ tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
155
+ INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
156
+ self._database_name.identifier(),
157
+ self._schema_name.identifier(),
158
+ tmp_table_name,
159
+ )
160
+ input_df.write.save_as_table( # type: ignore[call-overload]
161
+ table_name=INTERMEDIATE_TABLE_NAME,
162
+ mode="errorifexists",
163
+ table_type="temporary",
164
+ statement_params=statement_params,
165
+ )
161
166
 
162
167
  INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
163
168
 
164
169
  module_version_alias = "MODEL_VERSION_ALIAS"
165
- model_version_alias_sql = (
166
- f"WITH {module_version_alias} AS "
170
+ with_statements.append(
171
+ f"{module_version_alias} AS "
167
172
  f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
168
173
  )
169
174
 
@@ -174,7 +179,7 @@ class ModelVersionSQLClient:
174
179
  args_sql = ", ".join(args_sql_list)
175
180
 
176
181
  sql = textwrap.dedent(
177
- f"""{model_version_alias_sql}
182
+ f"""WITH {','.join(with_statements)}
178
183
  SELECT *,
179
184
  {module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
180
185
  FROM {INTERMEDIATE_TABLE_NAME}"""
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, cast
4
4
 
5
5
  import yaml
6
6
 
7
- from snowflake.ml._internal.utils import snowflake_env
8
7
  from snowflake.ml.model import type_hints
9
8
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
10
9
  from snowflake.ml.model._model_composer.model_method import (
@@ -84,11 +83,7 @@ class ModelManifest:
84
83
  ],
85
84
  )
86
85
 
87
- if (
88
- snowflake_env.get_current_snowflake_version(session)
89
- >= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
90
- ):
91
- manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
86
+ manifest_dict["user_data"] = self.generate_user_data_with_client_data(model_meta)
92
87
 
93
88
  with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
94
89
  # Anchors are not supported in the server, avoid that.
@@ -2,14 +2,12 @@
2
2
 
3
3
  from typing import Any, Dict, List, Literal, TypedDict
4
4
 
5
- from packaging import version
6
5
  from typing_extensions import NotRequired, Required
7
6
 
8
7
  from snowflake.ml.model import model_signature
9
8
 
10
9
  MODEL_MANIFEST_VERSION = "1.0"
11
10
 
12
- MANIFEST_USER_DATA_ENABLE_VERSION = version.parse("8.2.0")
13
11
  MANIFEST_CLIENT_DATA_KEY_NAME = "snowpark_ml_data"
14
12
  MANIFEST_CLIENT_DATA_SCHEMA_VERSION = "2024-02-01"
15
13
 
@@ -1 +1,10 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'anyio>=3.5.0,<4', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', 'snowflake-snowpark-python>=1.8.0,<2', 'typing-extensions>=4.1.0,<5']
1
+ REQUIREMENTS = [
2
+ "absl-py>=0.15,<2",
3
+ "anyio>=3.5.0,<4",
4
+ "numpy>=1.23,<2",
5
+ "packaging>=20.9,<24",
6
+ "pandas>=1.0.0,<2",
7
+ "pyyaml>=6.0,<7",
8
+ "snowflake-snowpark-python>=1.8.0,<2",
9
+ "typing-extensions>=4.1.0,<5"
10
+ ]
@@ -62,7 +62,6 @@ class ModelRuntime:
62
62
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
63
63
  for dep in _UDF_INFERENCE_DEPENDENCIES
64
64
  ],
65
- check_local_version=True,
66
65
  )
67
66
  else:
68
67
  self.runtime_env.include_if_absent(
@@ -70,7 +69,6 @@ class ModelRuntime:
70
69
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
71
70
  for dep in _UDF_INFERENCE_DEPENDENCIES + [snowml_pkg_spec]
72
71
  ],
73
- check_local_version=True,
74
72
  )
75
73
 
76
74
  def save(self, workspace_path: pathlib.Path) -> model_manifest_schema.ModelRuntimeDict:
@@ -1 +1,11 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'anyio>=3.5.0,<4', 'cloudpickle>=2.0.0', 'numpy>=1.23,<2', 'packaging>=20.9,<24', 'pandas>=1.0.0,<2', 'pyyaml>=6.0,<7', 'snowflake-snowpark-python>=1.8.0,<2', 'typing-extensions>=4.1.0,<5']
1
+ REQUIREMENTS = [
2
+ "absl-py>=0.15,<2",
3
+ "anyio>=3.5.0,<4",
4
+ "cloudpickle>=2.0.0",
5
+ "numpy>=1.23,<2",
6
+ "packaging>=20.9,<24",
7
+ "pandas>=1.0.0,<2",
8
+ "pyyaml>=6.0,<7",
9
+ "snowflake-snowpark-python>=1.8.0,<2",
10
+ "typing-extensions>=4.1.0,<5"
11
+ ]
@@ -0,0 +1,3 @@
1
+ REQUIREMENTS = [
2
+ "cloudpickle>=2.0.0"
3
+ ]
@@ -18,6 +18,7 @@ from snowflake.ml.model import model_signature, type_hints as model_types
18
18
  from snowflake.ml.model._packager.model_env import model_env
19
19
  from snowflake.ml.model._packager.model_meta import (
20
20
  _core_requirements,
21
+ _packaging_requirements,
21
22
  model_blob_meta,
22
23
  model_meta_schema,
23
24
  )
@@ -26,7 +27,8 @@ from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
26
27
  MODEL_METADATA_FILE = "model.yaml"
27
28
  MODEL_CODE_DIR = "code"
28
29
 
29
- _PACKAGING_CORE_DEPENDENCIES = _core_requirements.REQUIREMENTS
30
+ _PACKAGING_CORE_DEPENDENCIES = _core_requirements.REQUIREMENTS # Legacy Model only
31
+ _PACKAGING_REQUIREMENTS = _packaging_requirements.REQUIREMENTS # New Model only
30
32
  _SNOWFLAKE_PKG_NAME = "snowflake"
31
33
  _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml"
32
34
 
@@ -73,6 +75,8 @@ def create_model_metadata(
73
75
  model_dir_path = os.path.normpath(model_dir_path)
74
76
  embed_local_ml_library = kwargs.pop("embed_local_ml_library", False)
75
77
  legacy_save = kwargs.pop("_legacy_save", False)
78
+ relax_version = kwargs.pop("relax_version", False)
79
+
76
80
  if embed_local_ml_library:
77
81
  # Use the last one which is loaded first, that is mean, it is loaded from site-packages.
78
82
  # We could make sure that user does not overwrite our library with their code follow the same naming.
@@ -94,6 +98,8 @@ def create_model_metadata(
94
98
  pip_requirements=pip_requirements,
95
99
  python_version=python_version,
96
100
  embed_local_ml_library=embed_local_ml_library,
101
+ legacy_save=legacy_save,
102
+ relax_version=relax_version,
97
103
  )
98
104
 
99
105
  if embed_local_ml_library:
@@ -146,6 +152,8 @@ def _create_env_for_model_metadata(
146
152
  pip_requirements: Optional[List[str]] = None,
147
153
  python_version: Optional[str] = None,
148
154
  embed_local_ml_library: bool = False,
155
+ legacy_save: bool = False,
156
+ relax_version: bool = False,
149
157
  ) -> model_env.ModelEnv:
150
158
  env = model_env.ModelEnv()
151
159
 
@@ -154,11 +162,14 @@ def _create_env_for_model_metadata(
154
162
  env.pip_requirements = pip_requirements # type: ignore[assignment]
155
163
  env.python_version = python_version # type: ignore[assignment]
156
164
  env.snowpark_ml_version = snowml_env.VERSION
165
+
166
+ requirements_to_add = _PACKAGING_CORE_DEPENDENCIES if legacy_save else _PACKAGING_REQUIREMENTS
167
+
157
168
  if embed_local_ml_library:
158
169
  env.include_if_absent(
159
170
  [
160
171
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
161
- for dep in _PACKAGING_CORE_DEPENDENCIES
172
+ for dep in requirements_to_add
162
173
  ],
163
174
  check_local_version=True,
164
175
  )
@@ -166,11 +177,14 @@ def _create_env_for_model_metadata(
166
177
  env.include_if_absent(
167
178
  [
168
179
  model_env.ModelDependency(requirement=dep, pip_name=requirements.Requirement(dep).name)
169
- for dep in _PACKAGING_CORE_DEPENDENCIES + [env_utils.SNOWPARK_ML_PKG_NAME]
180
+ for dep in requirements_to_add + [env_utils.SNOWPARK_ML_PKG_NAME]
170
181
  ],
171
182
  check_local_version=True,
172
183
  )
173
184
 
185
+ if relax_version:
186
+ env.relax_version()
187
+
174
188
  return env
175
189
 
176
190
 
@@ -198,9 +198,12 @@ class BaseModelSaveOption(TypedDict):
198
198
  """Options for saving the model.
199
199
 
200
200
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
201
+ relax_version: Whether or not relax the version constraints of the dependencies if unresolvable. It detects any
202
+ ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
201
203
  """
202
204
 
203
205
  embed_local_ml_library: NotRequired[bool]
206
+ relax_version: NotRequired[bool]
204
207
  _legacy_save: NotRequired[bool]
205
208
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
206
209
 
@@ -4,11 +4,12 @@ import io
4
4
  import os
5
5
  import posixpath
6
6
  import sys
7
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import cloudpickle as cp
10
10
  import numpy as np
11
11
  from sklearn import model_selection
12
+ from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
12
13
 
13
14
  from snowflake.ml._internal import telemetry
14
15
  from snowflake.ml._internal.utils import (
@@ -41,23 +42,28 @@ DEFAULT_UDTF_NJOBS = 3
41
42
 
42
43
 
43
44
  def construct_cv_results(
45
+ estimator: Union[GridSearchCV, RandomizedSearchCV],
46
+ n_split: int,
47
+ param_grid: List[Dict[str, Any]],
44
48
  cv_results_raw_hex: List[Row],
45
49
  cross_validator_indices_length: int,
46
50
  parameter_grid_length: int,
47
- search_cv_kwargs: Dict[str, Any],
48
- ) -> Tuple[bool, Dict[str, Any], int, Set[str]]:
51
+ ) -> Tuple[bool, Dict[str, Any]]:
49
52
  """Construct the cross validation result from the UDF. Because we accelerate the process
50
53
  by the number of cross validation number, and the combination of parameter grids.
51
54
  Therefore, we need to stick them back together instead of returning the raw result
52
55
  to align with original sklearn result.
53
56
 
54
57
  Args:
58
+ estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn object of estimator
59
+ GridSearchCV or RandomizedSearchCV
60
+ n_split (int): The number of split, which is determined by build_cross_validator.get_n_splits(X, y, groups)
61
+ param_grid (List[Dict[str, Any]]): the list of parameter grid or parameter sampler
55
62
  cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination.
56
63
  Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a
57
64
  json format. Each cv_result is encoded into hex string.
58
65
  cross_validator_indices_length (int): the length of cross validator indices
59
66
  parameter_grid_length (int): the length of parameter grid combination
60
- search_cv_kwargs (Dict[str, Any]): the kwargs for GridSearchCV/RandomSearchCV.
61
67
 
62
68
  Raises:
63
69
  ValueError: Retrieved empty cross validation results
@@ -67,7 +73,7 @@ def construct_cv_results(
67
73
  RuntimeError: Cross validation results are unexpectedly empty for one fold.
68
74
 
69
75
  Returns:
70
- Tuple[bool, Dict[str, Any], int, Set[str]]: returns multimetric, cv_results_, best_param_index, scorers
76
+ Tuple[bool, Dict[str, Any]]: returns multimetric, cv_results_
71
77
  """
72
78
  # Filter corner cases: either the snowpark dataframe result is empty; or index length is empty
73
79
  if len(cv_results_raw_hex) == 0:
@@ -79,12 +85,8 @@ def construct_cv_results(
79
85
  if parameter_grid_length == 0:
80
86
  raise ValueError("Parameter index length is 0. Were there no candidates?")
81
87
 
82
- from scipy.stats import rankdata
83
-
84
88
  # cv_result maintains the original order
85
89
  multimetric = False
86
- cv_results_ = dict()
87
- scorers = set()
88
90
  # retrieve the cv_results from udtf table; results are encoded by hex and cloudpickle;
89
91
  # We are constructing the raw information back to original form
90
92
  if len(cv_results_raw_hex) != cross_validator_indices_length * parameter_grid_length:
@@ -94,7 +96,9 @@ def construct_cv_results(
94
96
  "Please retry or contact snowflake support."
95
97
  )
96
98
 
97
- for param_cv_indices, each_cv_result_hex in enumerate(cv_results_raw_hex):
99
+ out = []
100
+
101
+ for each_cv_result_hex in cv_results_raw_hex:
98
102
  # convert the hex string back to cv_results_
99
103
  hex_str = bytes.fromhex(each_cv_result_hex[0])
100
104
  with io.BytesIO(hex_str) as f_reload:
@@ -103,85 +107,46 @@ def construct_cv_results(
103
107
  raise RuntimeError(
104
108
  "Cross validation response is empty. This issue may be temporary - please try again."
105
109
  )
106
- for k, v in each_cv_result.items():
107
- cur_cv_idx = param_cv_indices % cross_validator_indices_length
108
- key = k
109
- if "split0_test_" in k:
110
+ temp_dict = dict()
111
+ """
112
+ This dictionary has the following keys
113
+ train_scores : dict of scorer name -> float
114
+ Score on training set (for all the scorers),
115
+ returned only if `return_train_score` is `True`.
116
+ test_scores : dict of scorer name -> float
117
+ Score on testing set (for all the scorers).
118
+ fit_time : float
119
+ Time spent for fitting in seconds.
120
+ score_time : float
121
+ Time spent for scoring in seconds.
122
+ """
123
+ if estimator.return_train_score:
124
+ if each_cv_result.get("split0_train_score", None):
125
+ # for single scorer, the split0_train_score only contains an array with one value
126
+ temp_dict["train_scores"] = each_cv_result["split0_train_score"][0]
127
+ else:
128
+ # if multimetric situation, the format would be
129
+ # {metric_name1: value, metric_name2: value, ...}
130
+ temp_dict["train_scores"] = {}
110
131
  # For multi-metric evaluation, the scores for all the scorers are available in the
111
132
  # cv_results_ dict at the keys ending with that scorer’s name ('_<scorer_name>')
112
133
  # instead of '_score'.
113
- scorers.add(k[len("split0_test_") :])
114
- key = k.replace("split0_test", f"split{cur_cv_idx}_test")
115
- if search_cv_kwargs.get("return_train_score", None) and "split0_train_" in k:
116
- key = k.replace("split0_train", f"split{cur_cv_idx}_train")
117
- elif k.startswith("param"):
118
- if cur_cv_idx != 0:
119
- continue
120
- if key:
121
- if key not in cv_results_:
122
- cv_results_[key] = v
123
- else:
124
- cv_results_[key] = np.concatenate([cv_results_[key], v])
125
-
126
- multimetric = len(scorers) > 1
127
- # Use numpy to re-calculate all the information in cv_results_ again
128
- # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
129
- # and average them by the idx_length;
130
- # idx_length is the number of cv folds; params_length is the number of parameter combinations
131
- scores_test = [
132
- np.reshape(
133
- np.concatenate(
134
- [cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(cross_validator_indices_length)]
135
- ),
136
- (cross_validator_indices_length, -1),
137
- )
138
- for score in scorers
139
- ]
140
-
141
- fit_score_test_matrix = np.stack(
142
- [
143
- np.reshape(cv_results_["mean_fit_time"], (cross_validator_indices_length, -1)),
144
- np.reshape(cv_results_["mean_score_time"], (cross_validator_indices_length, -1)),
145
- ]
146
- + scores_test
147
- )
148
- mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
149
- std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
150
-
151
- if search_cv_kwargs.get("return_train_score", None):
152
- scores_train = [
153
- np.reshape(
154
- np.concatenate(
155
- [cv_results_[f"split{cur_cv}_train_{score}"] for cur_cv in range(cross_validator_indices_length)]
156
- ),
157
- (cross_validator_indices_length, -1),
158
- )
159
- for score in scorers
160
- ]
161
- mean_fit_score_train_matrix = np.mean(scores_train, axis=1)
162
- std_fit_score_train_matrix = np.std(scores_train, axis=1)
163
-
164
- cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
165
- cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
166
- cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
167
- cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
168
- for idx, score in enumerate(scorers):
169
- cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
170
- cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
171
- if search_cv_kwargs.get("return_train_score", None):
172
- cv_results_[f"std_train_{score}"] = std_fit_score_train_matrix[idx]
173
- cv_results_[f"mean_train_{score}"] = mean_fit_score_train_matrix[idx]
174
- # re-compute the ranking again with mean_test_<score>.
175
- cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
176
- # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
177
- # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
178
- # In that case, default to first index.
179
- best_param_index = (
180
- np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
181
- if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
182
- else 0
183
- )
184
- return multimetric, cv_results_, best_param_index, scorers
134
+ for k, v in each_cv_result.items():
135
+ if "split0_train_" in k:
136
+ temp_dict["train_scores"][k[len("split0_train_") :]] = v
137
+ if isinstance(each_cv_result.get("split0_test_score"), np.ndarray):
138
+ temp_dict["test_scores"] = each_cv_result["split0_test_score"][0]
139
+ else:
140
+ temp_dict["test_scores"] = {}
141
+ for k, v in each_cv_result.items():
142
+ if "split0_test_" in k:
143
+ temp_dict["test_scores"][k[len("split0_test_") :]] = v
144
+ temp_dict["fit_time"] = each_cv_result["mean_fit_time"][0]
145
+ temp_dict["score_time"] = each_cv_result["mean_score_time"][0]
146
+ out.append(temp_dict)
147
+ first_test_score = out[0]["test_scores"]
148
+ multimetric = isinstance(first_test_score, dict)
149
+ return multimetric, estimator._format_results(param_grid, n_split, out)
185
150
 
186
151
 
187
152
  cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
@@ -288,7 +253,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
288
253
  inspect.currentframe(), self.__class__.__name__
289
254
  ),
290
255
  api_calls=[sproc],
291
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
292
256
  )
293
257
  udtf_statement_params = telemetry.get_function_usage_statement_params(
294
258
  project=_PROJECT,
@@ -297,7 +261,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
297
261
  inspect.currentframe(), self.__class__.__name__
298
262
  ),
299
263
  api_calls=[udtf],
300
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
264
+ custom_tags=dict([("hpo_udtf", True)]),
301
265
  )
302
266
 
303
267
  # Put locally serialized estimator on stage.
@@ -375,8 +339,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
375
339
  estimator = cp.load(local_estimator_file_obj)["estimator"]
376
340
 
377
341
  build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
342
+ from sklearn.utils.validation import indexable
343
+
344
+ X, y, _ = indexable(X, y, None)
345
+ n_splits = build_cross_validator.get_n_splits(X, y, None)
378
346
  # store the cross_validator's test indices only to save space
379
- cross_validator_indices = [test for _, test in build_cross_validator.split(X, y)]
347
+ cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
380
348
  local_indices_file_name = get_temp_file_path()
381
349
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
382
350
  cp.dump(cross_validator_indices, local_indices_file_obj)
@@ -529,14 +497,14 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
529
497
  )
530
498
  ),
531
499
  )
532
-
533
- multimetric, cv_results_, best_param_index, scorers = construct_cv_results(
500
+ # multimetric, cv_results_, best_param_index, scorers
501
+ multimetric, cv_results_ = construct_cv_results(
502
+ estimator,
503
+ n_splits,
504
+ list(param_grid),
534
505
  HP_raw_results.select("CV_RESULTS").sort(F.col("PARAM_CV_IND")).collect(),
535
506
  cross_validator_indices_length,
536
507
  parameter_grid_length,
537
- {
538
- "return_train_score": estimator.return_train_score,
539
- }, # TODO(xjiang): support more kwargs in here
540
508
  )
541
509
 
542
510
  estimator.cv_results_ = cv_results_
@@ -568,7 +536,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
568
536
  # With a non-custom callable, we can select the best score
569
537
  # based on the best index
570
538
  estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
571
- estimator.best_params_ = cv_results_["params"][best_param_index]
539
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]
572
540
 
573
541
  if original_refit:
574
542
  estimator.best_estimator_ = clone(estimator.estimator).set_params(