snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,554 @@
1
+ import importlib
2
+ import inspect
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import sys
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import cloudpickle as cp
10
+ import numpy as np
11
+ from scipy.stats import rankdata
12
+ from sklearn import model_selection
13
+
14
+ from snowflake.ml._internal import telemetry
15
+ from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
16
+ from snowflake.ml._internal.utils.temp_file_utils import (
17
+ cleanup_temp_files,
18
+ get_temp_file_path,
19
+ )
20
+ from snowflake.ml.modeling._internal.model_specifications import (
21
+ ModelSpecificationsBuilder,
22
+ )
23
+ from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
24
+ from snowflake.snowpark import DataFrame, Session, functions as F
25
+ from snowflake.snowpark._internal.utils import (
26
+ TempObjectType,
27
+ random_name_for_temp_object,
28
+ )
29
+ from snowflake.snowpark.functions import col, sproc, udtf
30
+ from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
31
+
32
+ cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))  # have cloudpickle serialize the module defining get_temp_file_path by value, not by import reference
33
+ cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))  # same for the module defining identifier.get_inferred_name
34
+
35
+ _PROJECT = "ModelDevelopment"  # project name passed to telemetry.get_function_usage_statement_params
36
+ DEFAULT_UDTF_NJOBS = 3  # NOTE(review): not referenced in the visible part of this file - confirm where it is consumed
37
+
38
+
39
+ class DistributedHPOTrainer(SnowparkModelTrainer):
40
+ """
41
+ A class for performing distributed hyperparameter optimization (HPO) using Snowpark.
42
+
43
+ This class inherits from SnowparkModelTrainer and extends its functionality
44
+ to support distributed HPO for machine learning models. It enables optimization
45
+ of hyperparameters by distributing the tasks across the warehouse using Snowpark.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ estimator: object,
51
+ dataset: DataFrame,
52
+ session: Session,
53
+ input_cols: List[str],
54
+ label_cols: Optional[List[str]],
55
+ sample_weight_col: Optional[str],
56
+ autogenerated: bool = False,
57
+ subproject: str = "",
58
+ ) -> None:
59
+ """
60
+ Initializes the DistributedHPOTrainer by passing every argument through unchanged to the SnowparkModelTrainer constructor.
61
+
62
+ Args:
63
+ estimator: SKLearn compatible estimator or transformer object.
64
+ dataset: The dataset used for training the model.
65
+ session: Snowflake session object to be used for training.
66
+ input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
67
+ label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn, or None.
68
+ sample_weight_col: The column name representing the weight of training examples, or None.
69
+ autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not.
70
+ subproject: Subproject name to be used in telemetry.
71
+ """
72
+ super().__init__(
73
+ estimator=estimator,
74
+ dataset=dataset,
75
+ session=session,
76
+ input_cols=input_cols,
77
+ label_cols=label_cols,
78
+ sample_weight_col=sample_weight_col,
79
+ autogenerated=autogenerated,
80
+ subproject=subproject,
81
+ )
82
+
83
+ # TODO(snandamuri): Copied this code as it is from the snowpark_handler.
84
+ # Update it to improve the readability.
85
+ def fit_search_snowpark(
86
+ self,
87
+ param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
88
+ dataset: DataFrame,
89
+ session: Session,
90
+ estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
91
+ dependencies: List[str],
92
+ udf_imports: List[str],
93
+ input_cols: List[str],
94
+ label_cols: Optional[List[str]],
95
+ sample_weight_col: Optional[str],
96
+ ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
97
+ from itertools import product
98
+
99
+ import cachetools
100
+ from sklearn.base import clone, is_classifier
101
+ from sklearn.calibration import check_cv
102
+
103
+ # Create one stage for data and for estimators.
104
+ temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
105
+ temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
106
+ session.sql(temp_stage_creation_query).collect()
107
+
108
+ # Stage data.
109
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
110
+ remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet"
111
+ dataset.write.copy_into_location( # type:ignore[call-overload]
112
+ remote_file_path, file_format_type="parquet", header=True, overwrite=True
113
+ )
114
+ imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()]
115
+
116
+ # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again
117
+ original_refit = estimator.refit
118
+
119
+ # Create a temp file and dump the estimator to that file.
120
+ estimator_file_name = get_temp_file_path()
121
+ params_to_evaluate = []
122
+ for param_to_eval in list(param_grid):
123
+ for k, v in param_to_eval.items():
124
+ param_to_eval[k] = [v]
125
+ params_to_evaluate.append([param_to_eval])
126
+
127
+ with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
128
+ # Set GridSearchCV refit as False and fit it again after retrieving the best param
129
+ estimator.refit = False
130
+ cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
131
+ stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
132
+ sproc_statement_params = telemetry.get_function_usage_statement_params(
133
+ project=_PROJECT,
134
+ subproject=self._subproject,
135
+ function_name=telemetry.get_statement_params_full_func_name(
136
+ inspect.currentframe(), self.__class__.__name__
137
+ ),
138
+ api_calls=[sproc],
139
+ custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
140
+ )
141
+ udtf_statement_params = telemetry.get_function_usage_statement_params(
142
+ project=_PROJECT,
143
+ subproject=self._subproject,
144
+ function_name=telemetry.get_statement_params_full_func_name(
145
+ inspect.currentframe(), self.__class__.__name__
146
+ ),
147
+ api_calls=[udtf],
148
+ custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
149
+ )
150
+
151
+ # Put locally serialized estimator on stage.
152
+ put_result = session.file.put(
153
+ estimator_file_name,
154
+ temp_stage_name,
155
+ auto_compress=False,
156
+ overwrite=True,
157
+ )
158
+ estimator_location = put_result[0].target
159
+ imports.append(f"@{temp_stage_name}/{estimator_location}")
160
+
161
+ search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
162
+ random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
163
+
164
+ required_deps = dependencies + [
165
+ "snowflake-snowpark-python<2",
166
+ "fastparquet<2023.11",
167
+ "pyarrow<14",
168
+ "cachetools<5",
169
+ ]
170
+
171
+ @sproc( # type: ignore[misc]
172
+ is_permanent=False,
173
+ name=search_sproc_name,
174
+ packages=required_deps, # type: ignore[arg-type]
175
+ replace=True,
176
+ session=session,
177
+ anonymous=True,
178
+ imports=imports, # type: ignore[arg-type]
179
+ statement_params=sproc_statement_params,
180
+ )
181
+ def _distributed_search(
182
+ session: Session,
183
+ imports: List[str],
184
+ stage_estimator_file_name: str,
185
+ input_cols: List[str],
186
+ label_cols: Optional[List[str]],
187
+ ) -> str:
188
+ import os
189
+ import time
190
+ from typing import Iterator
191
+
192
+ import cloudpickle as cp
193
+ import pandas as pd
194
+ import pyarrow.parquet as pq
195
+ from sklearn.metrics import check_scoring
196
+ from sklearn.metrics._scorer import _check_multimetric_scoring
197
+
198
+ for import_name in udf_imports:
199
+ importlib.import_module(import_name)
200
+
201
+ data_files = [
202
+ filename
203
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
204
+ if filename.startswith(temp_stage_name)
205
+ ]
206
+ partial_df = [
207
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
208
+ for file_name in data_files
209
+ ]
210
+ df = pd.concat(partial_df, ignore_index=True)
211
+ df.columns = [identifier.get_inferred_name(col) for col in df.columns]
212
+
213
+ X = df[input_cols]
214
+ y = df[label_cols].squeeze() if label_cols else None
215
+
216
+ local_estimator_file_name = get_temp_file_path()
217
+ session.file.get(stage_estimator_file_name, local_estimator_file_name)
218
+
219
+ local_estimator_file_path = os.path.join(
220
+ local_estimator_file_name, os.listdir(local_estimator_file_name)[0]
221
+ )
222
+ with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
223
+ estimator = cp.load(local_estimator_file_obj)["estimator"]
224
+
225
+ cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
226
+ indices = [test for _, test in cv_orig.split(X, y)]
227
+ local_indices_file_name = get_temp_file_path()
228
+ with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
229
+ cp.dump(indices, local_indices_file_obj)
230
+
231
+ # Put locally serialized indices on stage.
232
+ put_result = session.file.put(
233
+ local_indices_file_name,
234
+ temp_stage_name,
235
+ auto_compress=False,
236
+ overwrite=True,
237
+ )
238
+ indices_location = put_result[0].target
239
+ imports.append(f"@{temp_stage_name}/{indices_location}")
240
+ indices_len = len(indices)
241
+
242
+ assert estimator is not None
243
+
244
+ @cachetools.cached(cache={})
245
+ def _load_data_into_udf() -> Tuple[
246
+ Dict[str, pd.DataFrame],
247
+ Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
248
+ pd.DataFrame,
249
+ int,
250
+ List[Dict[str, Any]],
251
+ ]:
252
+ import pyarrow.parquet as pq
253
+
254
+ data_files = [
255
+ filename
256
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
257
+ if filename.startswith(temp_stage_name)
258
+ ]
259
+ partial_df = [
260
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
261
+ for file_name in data_files
262
+ ]
263
+ df = pd.concat(partial_df, ignore_index=True)
264
+ df.columns = [identifier.get_inferred_name(col) for col in df.columns]
265
+
266
+ # load estimator
267
+ local_estimator_file_path = os.path.join(
268
+ sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
269
+ )
270
+ with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
271
+ estimator_objects = cp.load(local_estimator_file_obj)
272
+ estimator = estimator_objects["estimator"]
273
+ params_to_evaluate = estimator_objects["param_grid"]
274
+
275
+ # load indices
276
+ local_indices_file_path = os.path.join(
277
+ sys._xoptions["snowflake_import_directory"], f"{indices_location}"
278
+ )
279
+ with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
280
+ indices = cp.load(local_indices_file_obj)
281
+
282
+ argspec = inspect.getfullargspec(estimator.fit)
283
+ args = {"X": df[input_cols]}
284
+
285
+ if label_cols:
286
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
287
+ args[label_arg_name] = df[label_cols].squeeze()
288
+
289
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
290
+ args["sample_weight"] = df[sample_weight_col].squeeze()
291
+ return args, estimator, indices, len(df), params_to_evaluate
292
+
293
+ class SearchCV:
294
+ def __init__(self) -> None:
295
+ args, estimator, indices, data_length, params_to_evaluate = _load_data_into_udf()
296
+ self.args = args
297
+ self.estimator = estimator
298
+ self.indices = indices
299
+ self.data_length = data_length
300
+ self.params_to_evaluate = params_to_evaluate
301
+
302
+ def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]:
303
+ if hasattr(estimator, "param_grid"):
304
+ self.estimator.param_grid = self.params_to_evaluate[params_idx]
305
+ else:
306
+ self.estimator.param_distributions = self.params_to_evaluate[params_idx]
307
+ full_indices = np.array([i for i in range(self.data_length)])
308
+ test_indice = self.indices[idx]
309
+ train_indice = np.setdiff1d(full_indices, test_indice)
310
+ self.estimator.cv = [(train_indice, test_indice)]
311
+ self.estimator.fit(**self.args)
312
+ binary_cv_results = None
313
+ with io.BytesIO() as f:
314
+ cp.dump(self.estimator.cv_results_, f)
315
+ f.seek(0)
316
+ binary_cv_results = f.getvalue().hex()
317
+ yield (binary_cv_results,)
318
+
319
+ def end_partition(self) -> None:
320
+ ...
321
+
322
+ session.udtf.register(
323
+ SearchCV,
324
+ output_schema=StructType([StructField("CV_RESULTS", StringType())]),
325
+ input_types=[IntegerType(), IntegerType()],
326
+ name=random_udtf_name,
327
+ packages=required_deps, # type: ignore[arg-type]
328
+ replace=True,
329
+ is_permanent=False,
330
+ imports=imports, # type: ignore[arg-type]
331
+ statement_params=udtf_statement_params,
332
+ )
333
+
334
+ HP_TUNING = F.table_function(random_udtf_name)
335
+
336
+ idx_length = int(indices_len)
337
+ params_length = len(param_grid)
338
+ idxs = [i for i in range(idx_length)]
339
+ param_indices, training_indices = [], []
340
+ for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs):
341
+ param_indices.append(param_idx)
342
+ training_indices.append(cv_idx)
343
+
344
+ pd_df = pd.DataFrame(
345
+ {
346
+ "PARAMS": param_indices,
347
+ "TRAIN_IND": training_indices,
348
+ "PARAM_INDEX": [i for i in range(idx_length * params_length)],
349
+ }
350
+ )
351
+ df = session.create_dataframe(pd_df)
352
+ results = df.select(
353
+ F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"),
354
+ (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])),
355
+ )
356
+
357
+ # cv_result maintains the original order
358
+ multimetric = False
359
+ cv_results_ = dict()
360
+ scorers = set()
361
+ for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()):
362
+ # each CV_RESULTS value is the hex-encoded cloudpickle dump of one UDTF call's cv_results_;
362
+ # decode the hex string back to bytes before unpickling it
364
+ hex_str = bytes.fromhex(val[0])
365
+ with io.BytesIO(hex_str) as f_reload:
366
+ each_cv_result = cp.load(f_reload)
367
+ for k, v in each_cv_result.items():
368
+ cur_cv = i % idx_length
369
+ key = k
370
+ if "split0_test_" in k:
371
+ # For multi-metric evaluation, the scores for all the scorers are available in the
372
+ # cv_results_ dict at the keys ending with that scorer’s name ('_<scorer_name>')
373
+ # instead of '_score'.
374
+ scorers.add(k[len("split0_test_") :])
375
+ key = k.replace("split0_test", f"split{cur_cv}_test")
376
+ elif k.startswith("param"):
377
+ if cur_cv != 0:
378
+ key = False
379
+ if key:
380
+ if key not in cv_results_:
381
+ cv_results_[key] = v
382
+ else:
383
+ cv_results_[key] = np.concatenate([cv_results_[key], v])
384
+
385
+ multimetric = len(scorers) > 1
386
+ # Use numpy to re-calculate all the information in cv_results_ again
387
+ # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
388
+ # and average them by the idx_length;
389
+ # idx_length is the number of cv folds; params_length is the number of parameter combinations
390
+ scores = [
391
+ np.reshape(
392
+ np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]),
393
+ (idx_length, -1),
394
+ )
395
+ for score in scorers
396
+ ]
397
+
398
+ fit_score_test_matrix = np.stack(
399
+ [
400
+ np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)),
401
+ np.reshape(cv_results_["mean_score_time"], (idx_length, -1)),
402
+ ]
403
+ + scores
404
+ )
405
+
406
+ mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
407
+ std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
408
+ cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
409
+ cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
410
+ cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
411
+ cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
412
+ for idx, score in enumerate(scorers):
413
+ cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
414
+ cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
415
+ # re-compute the ranking again with mean_test_<score>.
416
+ cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
417
+ # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
418
+ # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
419
+ # In that case, default to first index.
420
+ best_param_index = (
421
+ np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
422
+ if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
423
+ else 0
424
+ )
425
+
426
+ estimator.cv_results_ = cv_results_
427
+ estimator.multimetric_ = multimetric
428
+
429
+ # Reconstruct the sklearn estimator.
430
+ refit_metric = "score"
431
+ if callable(estimator.scoring):
432
+ scorers = estimator.scoring
433
+ elif estimator.scoring is None or isinstance(estimator.scoring, str):
434
+ scorers = check_scoring(estimator.estimator, estimator.scoring)
435
+ else:
436
+ scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
437
+ estimator._check_refit_for_multimetric(scorers)
438
+ refit_metric = original_refit
439
+
440
+ estimator.scorer_ = scorers
441
+
442
+ # check refit_metric now for a callable scorer that is multimetric
443
+ if callable(estimator.scoring) and estimator.multimetric_:
444
+ refit_metric = original_refit
445
+
446
+ # For multi-metric evaluation, store the best_index_, best_params_ and
447
+ # best_score_ iff refit is one of the scorer names
448
+ # In single metric evaluation, refit_metric is "score"
449
+ if original_refit or not estimator.multimetric_:
450
+ estimator.best_index_ = estimator._select_best_index(original_refit, refit_metric, cv_results_)
451
+ if not callable(original_refit):
452
+ # With a non-custom callable, we can select the best score
453
+ # based on the best index
454
+ estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
455
+ estimator.best_params_ = cv_results_["params"][best_param_index]
456
+
457
+ if original_refit:
458
+ estimator.best_estimator_ = clone(estimator.estimator).set_params(
459
+ **clone(estimator.best_params_, safe=False)
460
+ )
461
+
462
+ # Let the sproc use all cores to refit.
463
+ estimator.n_jobs = -1 if not estimator.n_jobs else estimator.n_jobs
464
+
465
+ # process the input as args
466
+ argspec = inspect.getfullargspec(estimator.fit)
467
+ args = {"X": X}
468
+ if label_cols:
469
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
470
+ args[label_arg_name] = y
471
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
472
+ args["sample_weight"] = df[sample_weight_col].squeeze()
473
+ estimator.refit = original_refit
474
+ refit_start_time = time.time()
475
+ estimator.best_estimator_.fit(**args)
476
+ refit_end_time = time.time()
477
+ estimator.refit_time_ = refit_end_time - refit_start_time
478
+
479
+ if hasattr(estimator.best_estimator_, "feature_names_in_"):
480
+ estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
481
+
482
+ local_result_file_name = get_temp_file_path()
483
+
484
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
485
+ cp.dump(estimator, local_result_file_obj)
486
+
487
+ session.file.put(
488
+ local_result_file_name,
489
+ temp_stage_name,
490
+ auto_compress=False,
491
+ overwrite=True,
492
+ )
493
+
494
+ # Note: you can add something like + "|" + str(df) to the return string
495
+ # to pass debug information to the caller.
496
+ return str(os.path.basename(local_result_file_name))
497
+
498
+ sproc_export_file_name = _distributed_search(
499
+ session,
500
+ imports,
501
+ stage_estimator_file_name,
502
+ input_cols,
503
+ label_cols,
504
+ )
505
+
506
+ local_estimator_path = get_temp_file_path()
507
+ session.file.get(
508
+ posixpath.join(temp_stage_name, sproc_export_file_name),
509
+ local_estimator_path,
510
+ )
511
+
512
+ with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
513
+ fit_estimator = cp.load(result_file_obj)
514
+
515
+ cleanup_temp_files([local_estimator_path])
516
+
517
+ return fit_estimator
518
+
519
def train(self) -> object:
    """
    Runs hyper parameter optimization for the wrapped search estimator.

    The candidate parameter settings are derived from the search estimator
    itself: a ``ParameterGrid`` over ``param_grid`` for ``GridSearchCV``, or a
    ``ParameterSampler`` over ``param_distributions`` (using the estimator's
    ``n_iter`` and ``random_state``) for ``RandomizedSearchCV``. They are then
    handed, together with the dataset, session and column configuration, to
    ``fit_search_snowpark``.

    Returns:
        Trained model
    """
    search_cv = self.estimator
    model_spec = ModelSpecificationsBuilder.build(model=search_cv)
    assert isinstance(search_cv, (model_selection.GridSearchCV, model_selection.RandomizedSearchCV))

    # When the base estimator leaves n_jobs unset (None) or asks for all
    # cores (-1), replace it with DEFAULT_UDTF_NJOBS.
    base_estimator = search_cv.estimator
    if hasattr(base_estimator, "n_jobs") and base_estimator.n_jobs in (None, -1):
        base_estimator.n_jobs = DEFAULT_UDTF_NJOBS

    if isinstance(search_cv, model_selection.GridSearchCV):
        # Exhaustive search: enumerate every combination in the grid.
        param_grid = model_selection.ParameterGrid(search_cv.param_grid)
    elif isinstance(search_cv, model_selection.RandomizedSearchCV):
        # Randomized search: sample n_iter candidates from the distributions.
        param_grid = model_selection.ParameterSampler(
            search_cv.param_distributions,
            n_iter=search_cv.n_iter,
            random_state=search_cv.random_state,
        )

    return self.fit_search_snowpark(
        param_grid=param_grid,
        dataset=self.dataset,
        session=self.session,
        estimator=search_cv,
        dependencies=model_spec.pkgDependencies,
        udf_imports=["sklearn"],
        input_cols=self.input_cols,
        label_cols=self.label_cols,
        sample_weight_col=self.sample_weight_col,
    )
@@ -1,35 +1,12 @@
1
- from typing import List, Optional, Protocol, Union
1
+ from typing import List, Optional, Protocol
2
2
 
3
3
  import pandas as pd
4
- from sklearn import model_selection
5
4
 
6
5
  from snowflake.snowpark import DataFrame, Session
7
6
 
8
7
 
9
8
  # TODO: Add more specific entities to type hint estimators instead of using `object`.
10
9
  class FitPredictHandlers(Protocol):
11
- def fit_snowpark(
12
- self,
13
- dataset: DataFrame,
14
- session: Session,
15
- estimator: object,
16
- dependencies: List[str],
17
- input_cols: List[str],
18
- label_cols: List[str],
19
- sample_weight_col: Optional[str],
20
- ) -> object:
21
- raise NotImplementedError
22
-
23
- def fit_pandas(
24
- self,
25
- dataset: pd.DataFrame,
26
- estimator: object,
27
- input_cols: List[str],
28
- label_cols: Optional[List[str]],
29
- sample_weight_col: Optional[str],
30
- ) -> object:
31
- raise NotImplementedError
32
-
33
10
  def batch_inference(
34
11
  self,
35
12
  dataset: DataFrame,
@@ -70,28 +47,6 @@ class FitPredictHandlers(Protocol):
70
47
 
71
48
  # TODO: Add more specific entities to type hint estimators instead of using `object`.
72
49
  class CVHandlers(Protocol):
73
- def fit_snowpark(
74
- self,
75
- dataset: DataFrame,
76
- session: Session,
77
- estimator: object,
78
- dependencies: List[str],
79
- input_cols: List[str],
80
- label_cols: List[str],
81
- sample_weight_col: Optional[str],
82
- ) -> object:
83
- raise NotImplementedError
84
-
85
- def fit_pandas(
86
- self,
87
- dataset: pd.DataFrame,
88
- estimator: object,
89
- input_cols: List[str],
90
- label_cols: Optional[List[str]],
91
- sample_weight_col: Optional[str],
92
- ) -> object:
93
- raise NotImplementedError
94
-
95
50
  def batch_inference(
96
51
  self,
97
52
  dataset: DataFrame,
@@ -128,17 +83,3 @@ class CVHandlers(Protocol):
128
83
  sample_weight_col: Optional[str],
129
84
  ) -> float:
130
85
  raise NotImplementedError
131
-
132
- def fit_search_snowpark(
133
- self,
134
- param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
135
- dataset: DataFrame,
136
- session: Session,
137
- estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
138
- dependencies: List[str],
139
- udf_imports: List[str],
140
- input_cols: List[str],
141
- label_cols: List[str],
142
- sample_weight_col: Optional[str],
143
- ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
144
- raise NotImplementedError